//! # baracuda-kernels-sys
//!
//! Raw `extern "C"` entry points for compiled bespoke kernels.
//! **You almost certainly want [`baracuda-kernels`] instead** — that
//! crate wraps these unsafe calls with typed plans, lifetime-checked
//! device buffers, and a proper Rust API.
//!
//! Functions in this crate take raw `void*` pointers, integer
//! dimensions, and a `cudaStream_t` cast as `*mut c_void`. They are
//! unsafe because:
//!
//! - They dereference the pointer arguments without bounds-checking.
//! - They assume the pointers are valid device addresses.
//! - They assume the workspace pointer (when non-null) points to at
//! least `workspace_bytes` of writable device memory.
//! - They assume the stream is a valid CUDA stream owned by the calling
//! thread's current context.
//!
//! ## Status codes
//!
//! All `*_run` and `*_can_implement` functions return an [`i32`] status:
//! - `0`: success.
//! - `1`: misaligned operand.
//! - `2`: invalid problem (e.g. M, N, or K is non-positive).
//! - `3`: not supported (this kernel doesn't implement the requested shape).
//! - `4`: workspace too small or null when required.
//! - `5`: internal kernel error (typically a launch failure).
//!
//! [`baracuda-kernels`]: https://docs.rs/baracuda-kernels
#![no_std]
use core::ffi::c_void;
// ============================================================================
// int8 GEMM — RRR layout, sm_80 (Phase 1)
// ============================================================================
//
// Layout convention `RRR`:
// A: row-major [M, K] leading dimension `lda` along K
// B: row-major [K, N] leading dimension `ldb` along N
// C: row-major [M, N] leading dimension `ldc` along N (optional;
// pass null + beta = 0 to skip)
// D: row-major [M, N] leading dimension `ldd` along N (always written)
//
// Accumulator: int32 via `mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32`.
// Epilogue: f32 alpha/beta on the int32 accum → saturating cast to s8 on
// store. Identity epilogue only in this SKU; bias variants follow in
// `gemm_s8_rrr_sm80_bias.cu` (later session).
#[cfg(feature = "sm80")]
unsafe extern "C" {
/// `S8` GEMM, RRR layout, Identity epilogue, sm_80.
///
/// All operands are row-major: A is `[M, K]` with leading dimension
/// `lda`, B is `[K, N]` with `ldb`, C is `[M, N]` with `ldc`, and D is
/// `[M, N]` with `ldd`. C is optional (pass null together with
/// `beta = 0` to skip it); D is always written. `alpha` and `beta` are
/// the `f32` epilogue scalars applied to the int32 accumulator before the
/// saturating cast to `s8` on store.
///
/// Returns one of the status codes listed in the crate-level docs
/// (`0` on success).
///
/// # Safety
/// All pointer args must be device-resident (or null where allowed) and
/// remain valid for the duration of the launch. `stream` must be a live
/// CUDA stream in the current context. When `workspace` is non-null it
/// must point to at least `workspace_bytes` of writable device memory.
pub fn baracuda_kernels_gemm_s8_rrr_sm80_run(
m: i32,
n: i32,
k: i32,
a: *const c_void,
lda: i64,
b: *const c_void,
ldb: i64,
c: *const c_void,
ldc: i64,
d: *mut c_void,
ldd: i64,
alpha: f32,
beta: f32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Workspace size in bytes for the `S8` RRR sm_80 Identity SKU at
/// the given problem size. Always returns zero today; reserved for
/// future SKUs that need scratch.
///
/// # Safety
/// Takes no pointer arguments; calling it is `unsafe` only because it is
/// an unqualified item of an `unsafe extern` block.
pub fn baracuda_kernels_gemm_s8_rrr_sm80_workspace_size(
m: i32,
n: i32,
k: i32,
) -> usize;
/// Pre-launch implementability check for the `S8` RRR sm_80
/// Identity SKU.
///
/// Returns `0` when the kernel can launch with the given shape and
/// leading dimensions; non-zero with the standard status-code
/// mapping otherwise. Does not launch a kernel and does not require
/// a stream.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `*_run`
/// function, but no device dereferences occur — only host-side
/// shape checks.
pub fn baracuda_kernels_gemm_s8_rrr_sm80_can_implement(
m: i32,
n: i32,
k: i32,
a: *const c_void,
lda: i64,
b: *const c_void,
ldb: i64,
c: *const c_void,
ldc: i64,
d: *const c_void,
ldd: i64,
) -> i32;
/// `U8` GEMM, RRR layout, Identity epilogue, sm_80.
///
/// Identical shape to the S8 variant; differs only in the
/// MMA operand encoding (`.u8.u8`) and the saturating cast back
/// to `u8` on store.
///
/// Returns one of the status codes listed in the crate-level docs
/// (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the S8 entry point.
pub fn baracuda_kernels_gemm_u8_rrr_sm80_run(
m: i32,
n: i32,
k: i32,
a: *const c_void,
lda: i64,
b: *const c_void,
ldb: i64,
c: *const c_void,
ldc: i64,
d: *mut c_void,
ldd: i64,
alpha: f32,
beta: f32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Workspace size in bytes for the `U8` RRR sm_80 Identity SKU at the
/// given problem size.
///
/// NOTE(review): presumably always zero, like the S8 counterpart —
/// confirm against the C++ launcher.
///
/// # Safety
/// Takes no pointer arguments; calling it is `unsafe` only because it is
/// an unqualified item of an `unsafe extern` block.
pub fn baracuda_kernels_gemm_u8_rrr_sm80_workspace_size(
m: i32,
n: i32,
k: i32,
) -> usize;
/// Pre-launch implementability check for the `U8` RRR sm_80 Identity
/// SKU.
///
/// Takes the same shape, operand-pointer, and leading-dimension
/// arguments as the S8 check and returns a crate-level status code
/// (`0` when the launch is implementable).
///
/// # Safety
/// Same pointer-validity contract as the corresponding `*_run` function.
/// NOTE(review): presumably host-side checks only, as documented for the
/// S8 check — confirm against the C++ launcher.
pub fn baracuda_kernels_gemm_u8_rrr_sm80_can_implement(
m: i32,
n: i32,
k: i32,
a: *const c_void,
lda: i64,
b: *const c_void,
ldb: i64,
c: *const c_void,
ldc: i64,
d: *const c_void,
ldd: i64,
) -> i32;
}
// ============================================================================
// int8 GEMM — RRR layout, sm_80, bias + activation epilogue family
// ============================================================================
//
// Eight launchers per element type: `{Bias, BiasRelu, BiasGelu, BiasSilu}`
// × `{f32, i32}` bias. The `bias` argument is an `[N]` device pointer of
// the indicated element type; it is broadcast across rows of D.
//
// All other arguments match the Identity launcher above.
#[cfg(feature = "sm80")]
unsafe extern "C" {
// -------- S8, f32 bias --------
/// `baracuda_kernels_gemm_s8_rrr_sm80_bias_f32_run` (baracuda kernels gemm s8 rrr sm80 bias f32 run).
pub fn baracuda_kernels_gemm_s8_rrr_sm80_bias_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_s8_rrr_sm80_bias_f32_can_implement` (baracuda kernels gemm s8 rrr sm80 bias f32 can implement).
pub fn baracuda_kernels_gemm_s8_rrr_sm80_bias_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// `baracuda_kernels_gemm_s8_rrr_sm80_bias_relu_f32_run` (baracuda kernels gemm s8 rrr sm80 bias relu f32 run).
pub fn baracuda_kernels_gemm_s8_rrr_sm80_bias_relu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_s8_rrr_sm80_bias_relu_f32_can_implement` (baracuda kernels gemm s8 rrr sm80 bias relu f32 can implement).
pub fn baracuda_kernels_gemm_s8_rrr_sm80_bias_relu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// `baracuda_kernels_gemm_s8_rrr_sm80_bias_gelu_f32_run` (baracuda kernels gemm s8 rrr sm80 bias gelu f32 run).
pub fn baracuda_kernels_gemm_s8_rrr_sm80_bias_gelu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_s8_rrr_sm80_bias_gelu_f32_can_implement` (baracuda kernels gemm s8 rrr sm80 bias gelu f32 can implement).
pub fn baracuda_kernels_gemm_s8_rrr_sm80_bias_gelu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// `baracuda_kernels_gemm_s8_rrr_sm80_bias_silu_f32_run` (baracuda kernels gemm s8 rrr sm80 bias silu f32 run).
pub fn baracuda_kernels_gemm_s8_rrr_sm80_bias_silu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_s8_rrr_sm80_bias_silu_f32_can_implement` (baracuda kernels gemm s8 rrr sm80 bias silu f32 can implement).
pub fn baracuda_kernels_gemm_s8_rrr_sm80_bias_silu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
// -------- S8, i32 bias --------
/// `baracuda_kernels_gemm_s8_rrr_sm80_bias_i32_run` (baracuda kernels gemm s8 rrr sm80 bias i32 run).
pub fn baracuda_kernels_gemm_s8_rrr_sm80_bias_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_s8_rrr_sm80_bias_i32_can_implement` (baracuda kernels gemm s8 rrr sm80 bias i32 can implement).
pub fn baracuda_kernels_gemm_s8_rrr_sm80_bias_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// `baracuda_kernels_gemm_s8_rrr_sm80_bias_relu_i32_run` (baracuda kernels gemm s8 rrr sm80 bias relu i32 run).
pub fn baracuda_kernels_gemm_s8_rrr_sm80_bias_relu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_s8_rrr_sm80_bias_relu_i32_can_implement` (baracuda kernels gemm s8 rrr sm80 bias relu i32 can implement).
pub fn baracuda_kernels_gemm_s8_rrr_sm80_bias_relu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// `baracuda_kernels_gemm_s8_rrr_sm80_bias_gelu_i32_run` (baracuda kernels gemm s8 rrr sm80 bias gelu i32 run).
pub fn baracuda_kernels_gemm_s8_rrr_sm80_bias_gelu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_s8_rrr_sm80_bias_gelu_i32_can_implement` (baracuda kernels gemm s8 rrr sm80 bias gelu i32 can implement).
pub fn baracuda_kernels_gemm_s8_rrr_sm80_bias_gelu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// `baracuda_kernels_gemm_s8_rrr_sm80_bias_silu_i32_run` (baracuda kernels gemm s8 rrr sm80 bias silu i32 run).
pub fn baracuda_kernels_gemm_s8_rrr_sm80_bias_silu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_s8_rrr_sm80_bias_silu_i32_can_implement` (baracuda kernels gemm s8 rrr sm80 bias silu i32 can implement).
pub fn baracuda_kernels_gemm_s8_rrr_sm80_bias_silu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
// -------- U8, f32 bias --------
/// `baracuda_kernels_gemm_u8_rrr_sm80_bias_f32_run` (baracuda kernels gemm u8 rrr sm80 bias f32 run).
pub fn baracuda_kernels_gemm_u8_rrr_sm80_bias_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_u8_rrr_sm80_bias_f32_can_implement` (baracuda kernels gemm u8 rrr sm80 bias f32 can implement).
pub fn baracuda_kernels_gemm_u8_rrr_sm80_bias_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// `baracuda_kernels_gemm_u8_rrr_sm80_bias_relu_f32_run` (baracuda kernels gemm u8 rrr sm80 bias relu f32 run).
pub fn baracuda_kernels_gemm_u8_rrr_sm80_bias_relu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_u8_rrr_sm80_bias_relu_f32_can_implement` (baracuda kernels gemm u8 rrr sm80 bias relu f32 can implement).
pub fn baracuda_kernels_gemm_u8_rrr_sm80_bias_relu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// `baracuda_kernels_gemm_u8_rrr_sm80_bias_gelu_f32_run` (baracuda kernels gemm u8 rrr sm80 bias gelu f32 run).
pub fn baracuda_kernels_gemm_u8_rrr_sm80_bias_gelu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_u8_rrr_sm80_bias_gelu_f32_can_implement` (baracuda kernels gemm u8 rrr sm80 bias gelu f32 can implement).
pub fn baracuda_kernels_gemm_u8_rrr_sm80_bias_gelu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// `baracuda_kernels_gemm_u8_rrr_sm80_bias_silu_f32_run` (baracuda kernels gemm u8 rrr sm80 bias silu f32 run).
pub fn baracuda_kernels_gemm_u8_rrr_sm80_bias_silu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_u8_rrr_sm80_bias_silu_f32_can_implement` (baracuda kernels gemm u8 rrr sm80 bias silu f32 can implement).
pub fn baracuda_kernels_gemm_u8_rrr_sm80_bias_silu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
// -------- U8, i32 bias --------
/// `baracuda_kernels_gemm_u8_rrr_sm80_bias_i32_run` (baracuda kernels gemm u8 rrr sm80 bias i32 run).
pub fn baracuda_kernels_gemm_u8_rrr_sm80_bias_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_u8_rrr_sm80_bias_i32_can_implement` (baracuda kernels gemm u8 rrr sm80 bias i32 can implement).
pub fn baracuda_kernels_gemm_u8_rrr_sm80_bias_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// `baracuda_kernels_gemm_u8_rrr_sm80_bias_relu_i32_run` (baracuda kernels gemm u8 rrr sm80 bias relu i32 run).
pub fn baracuda_kernels_gemm_u8_rrr_sm80_bias_relu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_u8_rrr_sm80_bias_relu_i32_can_implement` (baracuda kernels gemm u8 rrr sm80 bias relu i32 can implement).
pub fn baracuda_kernels_gemm_u8_rrr_sm80_bias_relu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// `baracuda_kernels_gemm_u8_rrr_sm80_bias_gelu_i32_run` (baracuda kernels gemm u8 rrr sm80 bias gelu i32 run).
pub fn baracuda_kernels_gemm_u8_rrr_sm80_bias_gelu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_u8_rrr_sm80_bias_gelu_i32_can_implement` (baracuda kernels gemm u8 rrr sm80 bias gelu i32 can implement).
pub fn baracuda_kernels_gemm_u8_rrr_sm80_bias_gelu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// `baracuda_kernels_gemm_u8_rrr_sm80_bias_silu_i32_run` (baracuda kernels gemm u8 rrr sm80 bias silu i32 run).
pub fn baracuda_kernels_gemm_u8_rrr_sm80_bias_silu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_u8_rrr_sm80_bias_silu_i32_can_implement` (baracuda kernels gemm u8 rrr sm80 bias silu i32 can implement).
pub fn baracuda_kernels_gemm_u8_rrr_sm80_bias_silu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
}
// ============================================================================
// FP8 GEMM — sm_89, full 20-SKU matrix
// ============================================================================
//
// SKU matrix: {E4M3, E5M2} × {RCR, RRR} × {Identity, Bias, BiasRelu,
// BiasGelu, BiasSilu} = 20 SKUs.
//
// Layout conventions:
//
// `RCR`:
// A: row-major [M, K] leading dimension `lda` along K
// B: col-major [K, N] leading dimension `ldb` along K
// C: row-major [M, N] leading dimension `ldc` along N (optional)
// D: row-major [M, N] leading dimension `ldd` along N
//
// `RRR`:
// A: row-major [M, K] leading dimension `lda` along K
// B: row-major [K, N] leading dimension `ldb` along N
// C: row-major [M, N] leading dimension `ldc` along N (optional)
// D: row-major [M, N] leading dimension `ldd` along N
//
// Tensor-core path: `mma.sync.aligned.m16n8k32.row.col.f32.{e4m3|e5m2}.
// {e4m3|e5m2}.f32`. Accumulator is F32; the epilogue casts to the
// output FP8 encoding with NVIDIA's `__NV_SATFINITE` semantics
// (round-half-to-even, clamp |x| to E4M3 max-finite 448.0 / E5M2
// max-finite 57344.0).
//
// Identity SKUs ship 3 fns each (`_run`, `_workspace_size`,
// `_can_implement`); bias-family SKUs share the Identity SKU's
// `_workspace_size` (their kernel shape is identical), so they ship
// only `_run` (which takes an extra `bias` argument) and
// `_can_implement`.
//
// Status codes are shared with the int-GEMM entry points (see
// crate-level doc).
#[cfg(feature = "sm89")]
unsafe extern "C" {
// -------- Identity: E4M3 × RCR (Phase 2 trailblazer) --------
/// FP8 E4M3 GEMM, RCR layout, Identity epilogue, sm_89.
///
/// RCR layout: A is row-major `[M, K]` (`lda` along K), B is col-major
/// `[K, N]` (`ldb` along K), C is row-major `[M, N]` (`ldc` along N,
/// optional), and D is row-major `[M, N]` (`ldd` along N). The
/// accumulator is F32; the epilogue casts to the output FP8 encoding with
/// saturate-to-finite semantics. Returns a crate-level status code
/// (`0` on success).
///
/// # Safety
/// All pointer args must be device-resident (or null where allowed) and
/// remain valid for the duration of the launch. `stream` must be a live
/// CUDA stream in the current context. Operand bytes are interpreted
/// as E4M3 storage (`__nv_fp8_storage_t`); no host-side validation is
/// performed.
pub fn baracuda_kernels_gemm_fp8_e4m3_rcr_sm89_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Workspace size in bytes for the FP8 E4M3 RCR sm_89 Identity SKU at
/// the given problem size.
///
/// # Safety
/// Takes no pointer arguments; calling it is `unsafe` only because it is
/// an unqualified item of an `unsafe extern` block.
pub fn baracuda_kernels_gemm_fp8_e4m3_rcr_sm89_workspace_size(
m: i32, n: i32, k: i32,
) -> usize;
/// Pre-launch implementability check for the FP8 E4M3 RCR sm_89
/// Identity SKU. Returns a crate-level status code (`0` when the launch
/// is implementable).
///
/// # Safety
/// Same pointer-validity contract as the corresponding `*_run` function.
/// NOTE(review): presumably host-side checks only — confirm against the
/// C++ launcher.
pub fn baracuda_kernels_gemm_fp8_e4m3_rcr_sm89_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
// -------- Identity: E4M3 × RRR --------
/// FP8 E4M3 GEMM, RRR layout, Identity epilogue, sm_89.
///
/// RRR layout: A `[M, K]`, B `[K, N]`, C `[M, N]` (optional), and D
/// `[M, N]` are all row-major, with `ldb` running along N. Returns a
/// crate-level status code (`0` on success).
///
/// # Safety
/// All pointer args must be device-resident (or null where allowed) and
/// remain valid for the duration of the launch. `stream` must be a live
/// CUDA stream in the current context.
pub fn baracuda_kernels_gemm_fp8_e4m3_rrr_sm89_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Workspace size in bytes for the FP8 E4M3 RRR sm_89 Identity SKU at
/// the given problem size.
///
/// # Safety
/// Takes no pointer arguments; calling it is `unsafe` only because it is
/// an unqualified item of an `unsafe extern` block.
pub fn baracuda_kernels_gemm_fp8_e4m3_rrr_sm89_workspace_size(
m: i32, n: i32, k: i32,
) -> usize;
/// Pre-launch implementability check for the FP8 E4M3 RRR sm_89
/// Identity SKU. Returns a crate-level status code (`0` when the launch
/// is implementable).
///
/// # Safety
/// Same pointer-validity contract as the corresponding `*_run` function.
/// NOTE(review): presumably host-side checks only — confirm against the
/// C++ launcher.
pub fn baracuda_kernels_gemm_fp8_e4m3_rrr_sm89_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
// -------- Identity: E5M2 × RCR --------
/// FP8 E5M2 GEMM, RCR layout, Identity epilogue, sm_89.
///
/// RCR layout: A is row-major `[M, K]` (`lda` along K), B is col-major
/// `[K, N]` (`ldb` along K), C (optional) and D are row-major `[M, N]`.
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// All pointer args must be device-resident (or null where allowed) and
/// remain valid for the duration of the launch. `stream` must be a live
/// CUDA stream in the current context.
pub fn baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Workspace size in bytes for the FP8 E5M2 RCR sm_89 Identity SKU at
/// the given problem size.
///
/// # Safety
/// Takes no pointer arguments; calling it is `unsafe` only because it is
/// an unqualified item of an `unsafe extern` block.
pub fn baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_workspace_size(
m: i32, n: i32, k: i32,
) -> usize;
/// Pre-launch implementability check for the FP8 E5M2 RCR sm_89
/// Identity SKU. Returns a crate-level status code (`0` when the launch
/// is implementable).
///
/// # Safety
/// Same pointer-validity contract as the corresponding `*_run` function.
/// NOTE(review): presumably host-side checks only — confirm against the
/// C++ launcher.
pub fn baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
// -------- Identity: E5M2 × RRR --------
/// FP8 E5M2 GEMM, RRR layout, Identity epilogue, sm_89.
///
/// RRR layout: A `[M, K]`, B `[K, N]`, C `[M, N]` (optional), and D
/// `[M, N]` are all row-major, with `ldb` running along N. Returns a
/// crate-level status code (`0` on success).
///
/// # Safety
/// All pointer args must be device-resident (or null where allowed) and
/// remain valid for the duration of the launch. `stream` must be a live
/// CUDA stream in the current context.
pub fn baracuda_kernels_gemm_fp8_e5m2_rrr_sm89_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Workspace size in bytes for the FP8 E5M2 RRR sm_89 Identity SKU at
/// the given problem size.
///
/// # Safety
/// Takes no pointer arguments; calling it is `unsafe` only because it is
/// an unqualified item of an `unsafe extern` block.
pub fn baracuda_kernels_gemm_fp8_e5m2_rrr_sm89_workspace_size(
m: i32, n: i32, k: i32,
) -> usize;
/// Pre-launch implementability check for the FP8 E5M2 RRR sm_89
/// Identity SKU. Returns a crate-level status code (`0` when the launch
/// is implementable).
///
/// # Safety
/// Same pointer-validity contract as the corresponding `*_run` function.
/// NOTE(review): presumably host-side checks only — confirm against the
/// C++ launcher.
pub fn baracuda_kernels_gemm_fp8_e5m2_rrr_sm89_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
// -------- Bias family: E4M3 × RCR --------
/// FP8 E4M3 GEMM, RCR layout, `Bias` epilogue, sm_89.
///
/// Same arguments as the Identity `*_run` plus a `bias` device pointer.
/// Returns a crate-level status code (`0` on success).
/// NOTE(review): the bias element type and length are not stated in this
/// file — confirm against the kernel source.
///
/// # Safety
/// Same contract as the Identity `*_run`, extended to `bias`.
pub fn baracuda_kernels_gemm_fp8_e4m3_rcr_sm89_bias_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the FP8 E4M3 RCR sm_89 `Bias`
/// SKU; takes no `bias` argument. Returns a crate-level status code.
///
/// # Safety
/// Same contract as the Identity `*_can_implement`.
pub fn baracuda_kernels_gemm_fp8_e4m3_rcr_sm89_bias_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// FP8 E4M3 GEMM, RCR layout, `BiasRelu` epilogue, sm_89.
///
/// Same arguments as the Identity `*_run` plus a `bias` device pointer.
/// Returns a crate-level status code (`0` on success).
/// NOTE(review): the bias element type and length are not stated in this
/// file — confirm against the kernel source.
///
/// # Safety
/// Same contract as the Identity `*_run`, extended to `bias`.
pub fn baracuda_kernels_gemm_fp8_e4m3_rcr_sm89_bias_relu_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the FP8 E4M3 RCR sm_89
/// `BiasRelu` SKU; takes no `bias` argument. Returns a crate-level
/// status code.
///
/// # Safety
/// Same contract as the Identity `*_can_implement`.
pub fn baracuda_kernels_gemm_fp8_e4m3_rcr_sm89_bias_relu_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// FP8 E4M3 GEMM, RCR layout, `BiasGelu` epilogue, sm_89.
///
/// Same arguments as the Identity `*_run` plus a `bias` device pointer.
/// Returns a crate-level status code (`0` on success).
/// NOTE(review): the bias element type and length are not stated in this
/// file — confirm against the kernel source.
///
/// # Safety
/// Same contract as the Identity `*_run`, extended to `bias`.
pub fn baracuda_kernels_gemm_fp8_e4m3_rcr_sm89_bias_gelu_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the FP8 E4M3 RCR sm_89
/// `BiasGelu` SKU; takes no `bias` argument. Returns a crate-level
/// status code.
///
/// # Safety
/// Same contract as the Identity `*_can_implement`.
pub fn baracuda_kernels_gemm_fp8_e4m3_rcr_sm89_bias_gelu_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// FP8 E4M3 GEMM, RCR layout, `BiasSilu` epilogue, sm_89.
///
/// Same arguments as the Identity `*_run` plus a `bias` device pointer.
/// Returns a crate-level status code (`0` on success).
/// NOTE(review): the bias element type and length are not stated in this
/// file — confirm against the kernel source.
///
/// # Safety
/// Same contract as the Identity `*_run`, extended to `bias`.
pub fn baracuda_kernels_gemm_fp8_e4m3_rcr_sm89_bias_silu_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the FP8 E4M3 RCR sm_89
/// `BiasSilu` SKU; takes no `bias` argument. Returns a crate-level
/// status code.
///
/// # Safety
/// Same contract as the Identity `*_can_implement`.
pub fn baracuda_kernels_gemm_fp8_e4m3_rcr_sm89_bias_silu_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
// -------- Bias family: E4M3 × RRR --------
/// FP8 E4M3 GEMM, RRR layout, `Bias` epilogue, sm_89.
///
/// Same arguments as the Identity `*_run` plus a `bias` device pointer.
/// Returns a crate-level status code (`0` on success).
/// NOTE(review): the bias element type and length are not stated in this
/// file — confirm against the kernel source.
///
/// # Safety
/// Same contract as the Identity `*_run`, extended to `bias`.
pub fn baracuda_kernels_gemm_fp8_e4m3_rrr_sm89_bias_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the FP8 E4M3 RRR sm_89 `Bias`
/// SKU; takes no `bias` argument. Returns a crate-level status code.
///
/// # Safety
/// Same contract as the Identity `*_can_implement`.
pub fn baracuda_kernels_gemm_fp8_e4m3_rrr_sm89_bias_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// FP8 E4M3 GEMM, RRR layout, `BiasRelu` epilogue, sm_89.
///
/// Same arguments as the Identity `*_run` plus a `bias` device pointer.
/// Returns a crate-level status code (`0` on success).
/// NOTE(review): the bias element type and length are not stated in this
/// file — confirm against the kernel source.
///
/// # Safety
/// Same contract as the Identity `*_run`, extended to `bias`.
pub fn baracuda_kernels_gemm_fp8_e4m3_rrr_sm89_bias_relu_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the FP8 E4M3 RRR sm_89
/// `BiasRelu` SKU; takes no `bias` argument. Returns a crate-level
/// status code.
///
/// # Safety
/// Same contract as the Identity `*_can_implement`.
pub fn baracuda_kernels_gemm_fp8_e4m3_rrr_sm89_bias_relu_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// FP8 E4M3 GEMM, RRR layout, `BiasGelu` epilogue, sm_89.
///
/// Same arguments as the Identity `*_run` plus a `bias` device pointer.
/// Returns a crate-level status code (`0` on success).
/// NOTE(review): the bias element type and length are not stated in this
/// file — confirm against the kernel source.
///
/// # Safety
/// Same contract as the Identity `*_run`, extended to `bias`.
pub fn baracuda_kernels_gemm_fp8_e4m3_rrr_sm89_bias_gelu_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the FP8 E4M3 RRR sm_89
/// `BiasGelu` SKU; takes no `bias` argument. Returns a crate-level
/// status code.
///
/// # Safety
/// Same contract as the Identity `*_can_implement`.
pub fn baracuda_kernels_gemm_fp8_e4m3_rrr_sm89_bias_gelu_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// FP8 E4M3 GEMM, RRR layout, `BiasSilu` epilogue, sm_89.
///
/// Same arguments as the Identity `*_run` plus a `bias` device pointer.
/// Returns a crate-level status code (`0` on success).
/// NOTE(review): the bias element type and length are not stated in this
/// file — confirm against the kernel source.
///
/// # Safety
/// Same contract as the Identity `*_run`, extended to `bias`.
pub fn baracuda_kernels_gemm_fp8_e4m3_rrr_sm89_bias_silu_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the FP8 E4M3 RRR sm_89
/// `BiasSilu` SKU; takes no `bias` argument. Returns a crate-level
/// status code.
///
/// # Safety
/// Same contract as the Identity `*_can_implement`.
pub fn baracuda_kernels_gemm_fp8_e4m3_rrr_sm89_bias_silu_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
// -------- Bias family: E5M2 × RCR --------
/// `baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_run` (baracuda kernels gemm fp8 e5m2 rcr sm89 bias run).
pub fn baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_can_implement` (baracuda kernels gemm fp8 e5m2 rcr sm89 bias can implement).
pub fn baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// `baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_relu_run` (baracuda kernels gemm fp8 e5m2 rcr sm89 bias relu run).
pub fn baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_relu_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_relu_can_implement` (baracuda kernels gemm fp8 e5m2 rcr sm89 bias relu can implement).
pub fn baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_relu_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// `baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_gelu_run` (baracuda kernels gemm fp8 e5m2 rcr sm89 bias gelu run).
pub fn baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_gelu_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_gelu_can_implement` (baracuda kernels gemm fp8 e5m2 rcr sm89 bias gelu can implement).
pub fn baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_gelu_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// FP8 E5M2 GEMM `_run` entry point: RCR layout, bias + SiLU epilogue, sm_89.
///
/// Takes the problem shape, operand pointers with leading dimensions,
/// a `bias` pointer, `alpha`/`beta`, a workspace buffer with its size,
/// and the CUDA stream as `*mut c_void`; returns a crate-level status
/// code (`0` = success).
/// NOTE(review): `bias` element type/extent and the `ld*` units are not
/// stated at this declaration — confirm against the fp8 section header.
pub fn baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_silu_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the FP8 E5M2 RCR bias + SiLU SKU (sm_89).
///
/// Takes the shape, operand pointers, and leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_silu_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
// -------- Bias family: E5M2 × RRR --------
/// FP8 E5M2 GEMM `_run` entry point: RRR layout, bias epilogue, sm_89.
///
/// Takes the problem shape, operand pointers with leading dimensions,
/// a `bias` pointer, `alpha`/`beta`, a workspace buffer with its size,
/// and the CUDA stream as `*mut c_void`; returns a crate-level status
/// code (`0` = success).
/// NOTE(review): `bias` element type/extent and the `ld*` units are not
/// stated at this declaration — confirm against the fp8 section header.
pub fn baracuda_kernels_gemm_fp8_e5m2_rrr_sm89_bias_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the FP8 E5M2 RRR bias SKU (sm_89).
///
/// Takes the shape, operand pointers, and leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_fp8_e5m2_rrr_sm89_bias_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// FP8 E5M2 GEMM `_run` entry point: RRR layout, bias + ReLU epilogue, sm_89.
///
/// Takes the problem shape, operand pointers with leading dimensions,
/// a `bias` pointer, `alpha`/`beta`, a workspace buffer with its size,
/// and the CUDA stream as `*mut c_void`; returns a crate-level status
/// code (`0` = success).
/// NOTE(review): `bias` element type/extent and the `ld*` units are not
/// stated at this declaration — confirm against the fp8 section header.
pub fn baracuda_kernels_gemm_fp8_e5m2_rrr_sm89_bias_relu_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the FP8 E5M2 RRR bias + ReLU SKU (sm_89).
///
/// Takes the shape, operand pointers, and leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_fp8_e5m2_rrr_sm89_bias_relu_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// FP8 E5M2 GEMM `_run` entry point: RRR layout, bias + GELU epilogue, sm_89.
///
/// Takes the problem shape, operand pointers with leading dimensions,
/// a `bias` pointer, `alpha`/`beta`, a workspace buffer with its size,
/// and the CUDA stream as `*mut c_void`; returns a crate-level status
/// code (`0` = success).
/// NOTE(review): `bias` element type/extent and the `ld*` units are not
/// stated at this declaration — confirm against the fp8 section header.
pub fn baracuda_kernels_gemm_fp8_e5m2_rrr_sm89_bias_gelu_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the FP8 E5M2 RRR bias + GELU SKU (sm_89).
///
/// Takes the shape, operand pointers, and leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_fp8_e5m2_rrr_sm89_bias_gelu_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
/// FP8 E5M2 GEMM `_run` entry point: RRR layout, bias + SiLU epilogue, sm_89.
///
/// Takes the problem shape, operand pointers with leading dimensions,
/// a `bias` pointer, `alpha`/`beta`, a workspace buffer with its size,
/// and the CUDA stream as `*mut c_void`; returns a crate-level status
/// code (`0` = success).
/// NOTE(review): `bias` element type/extent and the `ld*` units are not
/// stated at this declaration — confirm against the fp8 section header.
pub fn baracuda_kernels_gemm_fp8_e5m2_rrr_sm89_bias_silu_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *mut c_void, ldd: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the FP8 E5M2 RRR bias + SiLU SKU (sm_89).
///
/// Takes the shape, operand pointers, and leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_fp8_e5m2_rrr_sm89_bias_silu_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda: i64,
b: *const c_void, ldb: i64,
c: *const c_void, ldc: i64,
d: *const c_void, ldd: i64,
) -> i32;
}
// ============================================================================
// int4 GEMM — sm_89, S4 RCR Identity trailblazer
// ============================================================================
//
// Phase 2 int4 trailblazer (alpha.17). The S4 RCR Identity SKU proves
// the packed-storage path (two int4 per byte, low-nibble = even index,
// high-nibble = odd index along the K axis for A/B and along the N
// axis for D output) and the `mma.sync.aligned.m16n8k64.row.col.
// satfinite.s32.s4.s4.s32` PTX intrinsic. The U4 and RRR Identity
// variants are declared in the same block below; the bias-family
// variants follow in the next section.
//
// Layout convention `RCR`:
//
// A: row-major [M, K], leading dimension `lda_bytes` along K (= K/2
// storage bytes per row when there's no padding)
// B: col-major [K, N], leading dimension `ldb_bytes` along K (= K/2
// storage bytes per column when there's no padding)
// C: row-major [M, N], leading dimension `ldc_bytes` along N (= N/2
// storage bytes per row; optional — pass null
// + beta = 0 to skip)
// D: row-major [M, N], leading dimension `ldd_bytes` along N (= N/2
// storage bytes per row; always written)
//
// `M`, `N`, `K` are **element** counts; `lda_bytes` / `ldb_bytes` /
// `ldc_bytes` / `ldd_bytes` are **byte** counts (= storage-slot counts;
// the kernel walks byte arithmetic internally). Both `K` and `N` must
// be even (the packing is byte-aligned at K for A/B and at N for D);
// odd `K` or `N` returns status code 3.
//
// Tensor-core path: `mma.sync.aligned.m16n8k64.row.col.satfinite.s32.
// s4.s4.s32`. Accumulator is S32; the epilogue applies `f32 * alpha + f32
// * beta * dequant(C)` then saturating-casts back to S4 with round-
// half-to-even and clamp to `[-8, +7]`.
// Identity-epilogue int4 entry points: {S4, U4} × {RCR, RRR}, each with a
// `_run`, a `_workspace_size`, and a `_can_implement` declaration. The
// whole block is compiled only when the `sm89` feature is enabled.
#[cfg(feature = "sm89")]
unsafe extern "C" {
/// S4 GEMM, RCR layout, Identity epilogue, sm_89.
///
/// `lda_bytes` / `ldb_bytes` / `ldc_bytes` / `ldd_bytes` are in
/// **bytes** (= packed-pair storage slot counts). `m` / `n` / `k` are
/// element counts. Returns a crate-level status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident (or null where allowed)
/// and remain valid for the duration of the launch. `stream` must
/// be a live CUDA stream in the current context. Operand bytes are
/// interpreted as packed-pair int4 storage (low nibble = even index,
/// high nibble = odd index along the K axis for A/B and along the
/// N axis for C/D); no host-side validation is performed.
pub fn baracuda_kernels_gemm_s4_rcr_sm89_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Workspace-size query for the S4 RCR Identity SKU.
///
/// Takes only the problem shape (element counts) and returns a `usize`;
/// presumably the minimum `workspace_bytes` the matching `_run` needs —
/// TODO confirm against the C header.
pub fn baracuda_kernels_gemm_s4_rcr_sm89_workspace_size(
m: i32, n: i32, k: i32,
) -> usize;
/// `_can_implement` query for the S4 RCR Identity SKU.
///
/// Takes the shape, operand pointers, and byte leading dimensions of
/// `baracuda_kernels_gemm_s4_rcr_sm89_run` (no scalars, workspace, or
/// stream) and returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rcr_sm89_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// U4 GEMM, RCR layout, Identity epilogue, sm_89.
///
/// Identical shape to the S4 variant; differs only in the MMA
/// operand encoding (`.u4.u4`) and the saturating cast back to u4
/// (clamp `[0, 15]`).
///
/// # Safety
/// Same pointer-validity contract as the S4 entry point. Operand
/// bytes are interpreted as packed-pair u4 storage (low nibble +
/// high nibble); no host-side validation is performed.
pub fn baracuda_kernels_gemm_u4_rcr_sm89_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Workspace-size query for the U4 RCR Identity SKU.
///
/// Takes only the problem shape (element counts) and returns a `usize`;
/// presumably the minimum `workspace_bytes` the matching `_run` needs —
/// TODO confirm against the C header.
pub fn baracuda_kernels_gemm_u4_rcr_sm89_workspace_size(
m: i32, n: i32, k: i32,
) -> usize;
/// `_can_implement` query for the U4 RCR Identity SKU.
///
/// Takes the shape, operand pointers, and byte leading dimensions of
/// `baracuda_kernels_gemm_u4_rcr_sm89_run` (no scalars, workspace, or
/// stream) and returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_u4_rcr_sm89_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// S4 GEMM, RRR layout, Identity epilogue, sm_89.
///
/// `B` is row-major `[K, N]` pair-packed along N. The kernel
/// gathers two nibbles from two gmem K-rows to assemble one
/// packed-pair smem byte per output column (see header comment in
/// `baracuda_int4_rrr_sm89.cuh`).
///
/// # Safety
/// Same pointer-validity contract as the S4 RCR entry point.
pub fn baracuda_kernels_gemm_s4_rrr_sm89_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Workspace-size query for the S4 RRR Identity SKU.
///
/// Takes only the problem shape (element counts) and returns a `usize`;
/// presumably the minimum `workspace_bytes` the matching `_run` needs —
/// TODO confirm against the C header.
pub fn baracuda_kernels_gemm_s4_rrr_sm89_workspace_size(
m: i32, n: i32, k: i32,
) -> usize;
/// `_can_implement` query for the S4 RRR Identity SKU.
///
/// Takes the shape, operand pointers, and byte leading dimensions of
/// `baracuda_kernels_gemm_s4_rrr_sm89_run` (no scalars, workspace, or
/// stream) and returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rrr_sm89_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// U4 GEMM, RRR layout, Identity epilogue, sm_89.
///
/// Same argument list as the S4 RRR entry point; `ld*_bytes` are byte
/// counts. NOTE(review): presumably pairs the RRR operand handling of
/// the S4 RRR variant with the u4 encoding of the U4 RCR variant —
/// confirm against the kernel header.
///
/// # Safety
/// Same pointer-validity contract as the S4 RCR entry point.
pub fn baracuda_kernels_gemm_u4_rrr_sm89_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Workspace-size query for the U4 RRR Identity SKU.
///
/// Takes only the problem shape (element counts) and returns a `usize`;
/// presumably the minimum `workspace_bytes` the matching `_run` needs —
/// TODO confirm against the C header.
pub fn baracuda_kernels_gemm_u4_rrr_sm89_workspace_size(
m: i32, n: i32, k: i32,
) -> usize;
/// `_can_implement` query for the U4 RRR Identity SKU.
///
/// Takes the shape, operand pointers, and byte leading dimensions of
/// `baracuda_kernels_gemm_u4_rrr_sm89_run` (no scalars, workspace, or
/// stream) and returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_u4_rrr_sm89_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
}
// ============================================================================
// int4 GEMM — sm_89, bias + activation epilogue family
// ============================================================================
//
// 32 launchers covering `{S4, U4} × {RCR, RRR} × {Bias, BiasRelu,
// BiasGelu, BiasSilu} × {f32 bias, i32 bias}`. The `bias` argument is
// an `[N]` device pointer of the indicated element type; it is
// broadcast across rows of D. All other arguments match the Identity
// int4 launchers above.
//
// The kernel body is identical to the Identity case — only the
// epilogue chain (bias-add → optional scalar activation → saturating
// cast back to int4) varies. `_workspace_size` is shared with the
// Identity SKU of the same `(element, layout)` pair (call the Identity
// entry point for it); each bias launcher has its own `_can_implement`
// declared next to it, taking the same arguments as the Identity one.
#[cfg(feature = "sm89")]
unsafe extern "C" {
// -------- S4 × RCR --------
/// S4 GEMM, RCR layout, bias epilogue with `f32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `f32` elements, broadcast across
/// rows of D; all other arguments match the S4 RCR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the S4 RCR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_s4_rcr_sm89_bias_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the S4 RCR bias (`f32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rcr_sm89_bias_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// S4 GEMM, RCR layout, bias + ReLU epilogue with `f32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `f32` elements, broadcast across
/// rows of D; all other arguments match the S4 RCR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the S4 RCR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_s4_rcr_sm89_bias_relu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the S4 RCR bias + ReLU (`f32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rcr_sm89_bias_relu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// S4 GEMM, RCR layout, bias + GELU epilogue with `f32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `f32` elements, broadcast across
/// rows of D; all other arguments match the S4 RCR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the S4 RCR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_s4_rcr_sm89_bias_gelu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the S4 RCR bias + GELU (`f32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rcr_sm89_bias_gelu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// S4 GEMM, RCR layout, bias + SiLU epilogue with `f32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `f32` elements, broadcast across
/// rows of D; all other arguments match the S4 RCR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the S4 RCR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_s4_rcr_sm89_bias_silu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the S4 RCR bias + SiLU (`f32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rcr_sm89_bias_silu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// S4 GEMM, RCR layout, bias epilogue with `i32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `i32` elements, broadcast across
/// rows of D; all other arguments match the S4 RCR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the S4 RCR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_s4_rcr_sm89_bias_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the S4 RCR bias (`i32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rcr_sm89_bias_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// S4 GEMM, RCR layout, bias + ReLU epilogue with `i32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `i32` elements, broadcast across
/// rows of D; all other arguments match the S4 RCR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the S4 RCR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_s4_rcr_sm89_bias_relu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the S4 RCR bias + ReLU (`i32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rcr_sm89_bias_relu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// S4 GEMM, RCR layout, bias + GELU epilogue with `i32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `i32` elements, broadcast across
/// rows of D; all other arguments match the S4 RCR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the S4 RCR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_s4_rcr_sm89_bias_gelu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the S4 RCR bias + GELU (`i32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rcr_sm89_bias_gelu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// S4 GEMM, RCR layout, bias + SiLU epilogue with `i32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `i32` elements, broadcast across
/// rows of D; all other arguments match the S4 RCR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the S4 RCR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_s4_rcr_sm89_bias_silu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the S4 RCR bias + SiLU (`i32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rcr_sm89_bias_silu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
// -------- U4 × RCR --------
/// U4 GEMM, RCR layout, bias epilogue with `f32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `f32` elements, broadcast across
/// rows of D; all other arguments match the U4 RCR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the U4 RCR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_u4_rcr_sm89_bias_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the U4 RCR bias (`f32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_u4_rcr_sm89_bias_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// U4 GEMM, RCR layout, bias + ReLU epilogue with `f32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `f32` elements, broadcast across
/// rows of D; all other arguments match the U4 RCR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the U4 RCR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_u4_rcr_sm89_bias_relu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the U4 RCR bias + ReLU (`f32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_u4_rcr_sm89_bias_relu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// U4 GEMM, RCR layout, bias + GELU epilogue with `f32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `f32` elements, broadcast across
/// rows of D; all other arguments match the U4 RCR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the U4 RCR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_u4_rcr_sm89_bias_gelu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the U4 RCR bias + GELU (`f32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_u4_rcr_sm89_bias_gelu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// U4 GEMM, RCR layout, bias + SiLU epilogue with `f32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `f32` elements, broadcast across
/// rows of D; all other arguments match the U4 RCR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the U4 RCR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_u4_rcr_sm89_bias_silu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the U4 RCR bias + SiLU (`f32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_u4_rcr_sm89_bias_silu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// U4 GEMM, RCR layout, bias epilogue with `i32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `i32` elements, broadcast across
/// rows of D; all other arguments match the U4 RCR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the U4 RCR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_u4_rcr_sm89_bias_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the U4 RCR bias (`i32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_u4_rcr_sm89_bias_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// U4 GEMM, RCR layout, bias + ReLU epilogue with `i32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `i32` elements, broadcast across
/// rows of D; all other arguments match the U4 RCR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the U4 RCR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_u4_rcr_sm89_bias_relu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the U4 RCR bias + ReLU (`i32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_u4_rcr_sm89_bias_relu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// U4 GEMM, RCR layout, bias + GELU epilogue with `i32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `i32` elements, broadcast across
/// rows of D; all other arguments match the U4 RCR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the U4 RCR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_u4_rcr_sm89_bias_gelu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the U4 RCR bias + GELU (`i32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_u4_rcr_sm89_bias_gelu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// U4 GEMM, RCR layout, bias + SiLU epilogue with `i32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `i32` elements, broadcast across
/// rows of D; all other arguments match the U4 RCR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the U4 RCR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_u4_rcr_sm89_bias_silu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the U4 RCR bias + SiLU (`i32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_u4_rcr_sm89_bias_silu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
// -------- S4 × RRR --------
/// S4 GEMM, RRR layout, bias epilogue with `f32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `f32` elements, broadcast across
/// rows of D; all other arguments match the S4 RRR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the S4 RRR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_s4_rrr_sm89_bias_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the S4 RRR bias (`f32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rrr_sm89_bias_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// S4 GEMM, RRR layout, bias + ReLU epilogue with `f32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `f32` elements, broadcast across
/// rows of D; all other arguments match the S4 RRR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the S4 RRR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_s4_rrr_sm89_bias_relu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the S4 RRR bias + ReLU (`f32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rrr_sm89_bias_relu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// S4 GEMM, RRR layout, bias + GELU epilogue with `f32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `f32` elements, broadcast across
/// rows of D; all other arguments match the S4 RRR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the S4 RRR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_s4_rrr_sm89_bias_gelu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the S4 RRR bias + GELU (`f32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rrr_sm89_bias_gelu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// S4 GEMM, RRR layout, bias + SiLU epilogue with `f32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `f32` elements, broadcast across
/// rows of D; all other arguments match the S4 RRR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the S4 RRR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_s4_rrr_sm89_bias_silu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the S4 RRR bias + SiLU (`f32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rrr_sm89_bias_silu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// S4 GEMM, RRR layout, bias epilogue with `i32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `i32` elements, broadcast across
/// rows of D; all other arguments match the S4 RRR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the S4 RRR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_s4_rrr_sm89_bias_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the S4 RRR bias (`i32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rrr_sm89_bias_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// S4 GEMM, RRR layout, bias + ReLU epilogue with `i32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `i32` elements, broadcast across
/// rows of D; all other arguments match the S4 RRR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the S4 RRR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_s4_rrr_sm89_bias_relu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the S4 RRR bias + ReLU (`i32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rrr_sm89_bias_relu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// S4 GEMM, RRR layout, bias + GELU epilogue with `i32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `i32` elements, broadcast across
/// rows of D; all other arguments match the S4 RRR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the S4 RRR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_s4_rrr_sm89_bias_gelu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the S4 RRR bias + GELU (`i32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rrr_sm89_bias_gelu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// S4 GEMM, RRR layout, bias + SiLU epilogue with `i32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `i32` elements, broadcast across
/// rows of D; all other arguments match the S4 RRR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the S4 RRR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_s4_rrr_sm89_bias_silu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the S4 RRR bias + SiLU (`i32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_s4_rrr_sm89_bias_silu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
// -------- U4 × RRR --------
/// U4 GEMM, RRR layout, bias epilogue with `f32` bias, sm_89.
///
/// `bias` is an `[N]` device pointer of `f32` elements, broadcast across
/// rows of D; all other arguments match the U4 RRR Identity launcher.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the U4 RRR Identity entry point,
/// extended to `bias`.
pub fn baracuda_kernels_gemm_u4_rrr_sm89_bias_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `_can_implement` query for the U4 RRR bias (`f32` bias) SKU, sm_89.
///
/// Takes the shape, operand pointers, and byte leading dimensions of the
/// matching `_run` (no `bias`, scalars, workspace, or stream) and
/// returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gemm_u4_rrr_sm89_bias_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// Raw launcher for the `gemm_u4_rrr_sm89_bias_relu_f32` kernel.
///
/// Arguments: problem size `m` / `n` / `k`; operand pointers `a`, `b`,
/// `c` and output pointer `d`, each with its `ld*_bytes` leading
/// dimension; the `bias` pointer; the `alpha` / `beta` scalars; the
/// workspace pointer and size; and the CUDA stream cast to `*mut c_void`.
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// NOTE(review): dtype, layout, epilogue and output type are encoded
/// only in the symbol name (`u4`, `rrr`, `bias_relu`, `f32`) — confirm
/// the exact semantics against the kernel source.
///
/// # Safety
/// The crate-level contract applies: pointers are used without bounds
/// checks and must be valid device addresses, a non-null `workspace`
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream for the current context.
pub fn baracuda_kernels_gemm_u4_rrr_sm89_bias_relu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability probe for the `gemm_u4_rrr_sm89_bias_relu_f32` kernel.
///
/// Takes the problem size and the `a` / `b` / `c` / `d` pointers with
/// their `ld*_bytes` leading dimensions (no bias, scalars, workspace or
/// stream) and returns the crate-wide `i32` status code.
///
/// NOTE(review): presumably validates the problem without launching —
/// confirm against the kernel source.
///
/// # Safety
/// The crate-level pointer-validity contract applies to `a`, `b`, `c`
/// and `d`.
pub fn baracuda_kernels_gemm_u4_rrr_sm89_bias_relu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// Raw launcher for the `gemm_u4_rrr_sm89_bias_gelu_f32` kernel.
///
/// Arguments: problem size `m` / `n` / `k`; operand pointers `a`, `b`,
/// `c` and output pointer `d`, each with its `ld*_bytes` leading
/// dimension; the `bias` pointer; the `alpha` / `beta` scalars; the
/// workspace pointer and size; and the CUDA stream cast to `*mut c_void`.
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// NOTE(review): dtype, layout, epilogue and output type are encoded
/// only in the symbol name (`u4`, `rrr`, `bias_gelu`, `f32`) — confirm
/// the exact semantics against the kernel source.
///
/// # Safety
/// The crate-level contract applies: pointers are used without bounds
/// checks and must be valid device addresses, a non-null `workspace`
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream for the current context.
pub fn baracuda_kernels_gemm_u4_rrr_sm89_bias_gelu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability probe for the `gemm_u4_rrr_sm89_bias_gelu_f32` kernel.
///
/// Takes the problem size and the `a` / `b` / `c` / `d` pointers with
/// their `ld*_bytes` leading dimensions (no bias, scalars, workspace or
/// stream) and returns the crate-wide `i32` status code.
///
/// NOTE(review): presumably validates the problem without launching —
/// confirm against the kernel source.
///
/// # Safety
/// The crate-level pointer-validity contract applies to `a`, `b`, `c`
/// and `d`.
pub fn baracuda_kernels_gemm_u4_rrr_sm89_bias_gelu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// Raw launcher for the `gemm_u4_rrr_sm89_bias_silu_f32` kernel.
///
/// Arguments: problem size `m` / `n` / `k`; operand pointers `a`, `b`,
/// `c` and output pointer `d`, each with its `ld*_bytes` leading
/// dimension; the `bias` pointer; the `alpha` / `beta` scalars; the
/// workspace pointer and size; and the CUDA stream cast to `*mut c_void`.
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// NOTE(review): dtype, layout, epilogue and output type are encoded
/// only in the symbol name (`u4`, `rrr`, `bias_silu`, `f32`) — confirm
/// the exact semantics against the kernel source.
///
/// # Safety
/// The crate-level contract applies: pointers are used without bounds
/// checks and must be valid device addresses, a non-null `workspace`
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream for the current context.
pub fn baracuda_kernels_gemm_u4_rrr_sm89_bias_silu_f32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability probe for the `gemm_u4_rrr_sm89_bias_silu_f32` kernel.
///
/// Takes the problem size and the `a` / `b` / `c` / `d` pointers with
/// their `ld*_bytes` leading dimensions (no bias, scalars, workspace or
/// stream) and returns the crate-wide `i32` status code.
///
/// NOTE(review): presumably validates the problem without launching —
/// confirm against the kernel source.
///
/// # Safety
/// The crate-level pointer-validity contract applies to `a`, `b`, `c`
/// and `d`.
pub fn baracuda_kernels_gemm_u4_rrr_sm89_bias_silu_f32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// Raw launcher for the `gemm_u4_rrr_sm89_bias_i32` kernel.
///
/// Arguments: problem size `m` / `n` / `k`; operand pointers `a`, `b`,
/// `c` and output pointer `d`, each with its `ld*_bytes` leading
/// dimension; the `bias` pointer; the `alpha` / `beta` scalars; the
/// workspace pointer and size; and the CUDA stream cast to `*mut c_void`.
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// NOTE(review): dtype, layout, epilogue and output type are encoded
/// only in the symbol name (`u4`, `rrr`, `bias`, `i32`) — confirm the
/// exact semantics against the kernel source.
///
/// # Safety
/// The crate-level contract applies: pointers are used without bounds
/// checks and must be valid device addresses, a non-null `workspace`
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream for the current context.
pub fn baracuda_kernels_gemm_u4_rrr_sm89_bias_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability probe for the `gemm_u4_rrr_sm89_bias_i32` kernel.
///
/// Takes the problem size and the `a` / `b` / `c` / `d` pointers with
/// their `ld*_bytes` leading dimensions (no bias, scalars, workspace or
/// stream) and returns the crate-wide `i32` status code.
///
/// NOTE(review): presumably validates the problem without launching —
/// confirm against the kernel source.
///
/// # Safety
/// The crate-level pointer-validity contract applies to `a`, `b`, `c`
/// and `d`.
pub fn baracuda_kernels_gemm_u4_rrr_sm89_bias_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// Raw launcher for the `gemm_u4_rrr_sm89_bias_relu_i32` kernel.
///
/// Arguments: problem size `m` / `n` / `k`; operand pointers `a`, `b`,
/// `c` and output pointer `d`, each with its `ld*_bytes` leading
/// dimension; the `bias` pointer; the `alpha` / `beta` scalars; the
/// workspace pointer and size; and the CUDA stream cast to `*mut c_void`.
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// NOTE(review): dtype, layout, epilogue and output type are encoded
/// only in the symbol name (`u4`, `rrr`, `bias_relu`, `i32`) — confirm
/// the exact semantics against the kernel source.
///
/// # Safety
/// The crate-level contract applies: pointers are used without bounds
/// checks and must be valid device addresses, a non-null `workspace`
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream for the current context.
pub fn baracuda_kernels_gemm_u4_rrr_sm89_bias_relu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability probe for the `gemm_u4_rrr_sm89_bias_relu_i32` kernel.
///
/// Takes the problem size and the `a` / `b` / `c` / `d` pointers with
/// their `ld*_bytes` leading dimensions (no bias, scalars, workspace or
/// stream) and returns the crate-wide `i32` status code.
///
/// NOTE(review): presumably validates the problem without launching —
/// confirm against the kernel source.
///
/// # Safety
/// The crate-level pointer-validity contract applies to `a`, `b`, `c`
/// and `d`.
pub fn baracuda_kernels_gemm_u4_rrr_sm89_bias_relu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// Raw launcher for the `gemm_u4_rrr_sm89_bias_gelu_i32` kernel.
///
/// Arguments: problem size `m` / `n` / `k`; operand pointers `a`, `b`,
/// `c` and output pointer `d`, each with its `ld*_bytes` leading
/// dimension; the `bias` pointer; the `alpha` / `beta` scalars; the
/// workspace pointer and size; and the CUDA stream cast to `*mut c_void`.
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// NOTE(review): dtype, layout, epilogue and output type are encoded
/// only in the symbol name (`u4`, `rrr`, `bias_gelu`, `i32`) — confirm
/// the exact semantics against the kernel source.
///
/// # Safety
/// The crate-level contract applies: pointers are used without bounds
/// checks and must be valid device addresses, a non-null `workspace`
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream for the current context.
pub fn baracuda_kernels_gemm_u4_rrr_sm89_bias_gelu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability probe for the `gemm_u4_rrr_sm89_bias_gelu_i32` kernel.
///
/// Takes the problem size and the `a` / `b` / `c` / `d` pointers with
/// their `ld*_bytes` leading dimensions (no bias, scalars, workspace or
/// stream) and returns the crate-wide `i32` status code.
///
/// NOTE(review): presumably validates the problem without launching —
/// confirm against the kernel source.
///
/// # Safety
/// The crate-level pointer-validity contract applies to `a`, `b`, `c`
/// and `d`.
pub fn baracuda_kernels_gemm_u4_rrr_sm89_bias_gelu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
/// Raw launcher for the `gemm_u4_rrr_sm89_bias_silu_i32` kernel.
///
/// Arguments: problem size `m` / `n` / `k`; operand pointers `a`, `b`,
/// `c` and output pointer `d`, each with its `ld*_bytes` leading
/// dimension; the `bias` pointer; the `alpha` / `beta` scalars; the
/// workspace pointer and size; and the CUDA stream cast to `*mut c_void`.
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// NOTE(review): dtype, layout, epilogue and output type are encoded
/// only in the symbol name (`u4`, `rrr`, `bias_silu`, `i32`) — confirm
/// the exact semantics against the kernel source.
///
/// # Safety
/// The crate-level contract applies: pointers are used without bounds
/// checks and must be valid device addresses, a non-null `workspace`
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream for the current context.
pub fn baracuda_kernels_gemm_u4_rrr_sm89_bias_silu_i32_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *mut c_void, ldd_bytes: i64,
bias: *const c_void,
alpha: f32, beta: f32,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability probe for the `gemm_u4_rrr_sm89_bias_silu_i32` kernel.
///
/// Takes the problem size and the `a` / `b` / `c` / `d` pointers with
/// their `ld*_bytes` leading dimensions (no bias, scalars, workspace or
/// stream) and returns the crate-wide `i32` status code.
///
/// NOTE(review): presumably validates the problem without launching —
/// confirm against the kernel source.
///
/// # Safety
/// The crate-level pointer-validity contract applies to `a`, `b`, `c`
/// and `d`.
pub fn baracuda_kernels_gemm_u4_rrr_sm89_bias_silu_i32_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
c: *const c_void, ldc_bytes: i64,
d: *const c_void, ldd_bytes: i64,
) -> i32;
}
// ============================================================================
// Binary (B1) GEMM — sm_89 (Identity-only; RCR and RRR layouts)
// ============================================================================
//
// Distinct programming model: `D[i,j] = sum_k popcount(A[i, k_byte] XOR
// B[k_byte, j])` (raw int32 accumulator, no re-quantization back to b1
// and no α/β/bias/activation chain).
//
// Layout convention `RCR`:
//
// A: row-major [M, K bits], leading dimension `lda_bytes` along K
// (= K/8 storage bytes per row)
// B: col-major [K, N bits], leading dimension `ldb_bytes` along K
// (= K/8 storage bytes per column)
// D: row-major [M, N i32], leading dimension `ldd_elements` along N
// (int32 element count, NOT bytes — D is a
// plain int32 matrix with no packing)
//
// `M`, `N`, `K` are **element** counts; `lda_bytes` / `ldb_bytes` are
// **byte** counts; `ldd_elements` is in **i32 element** count. `K` must
// be divisible by 8 (packing is byte-aligned). The RCR layout places no
// constraint on N (output is plain int32); the RRR entry points declared
// below additionally require N to be divisible by 8.
//
// Tensor-core path:
// `mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc`.
#[cfg(feature = "sm89")]
unsafe extern "C" {
/// Binary (B1) GEMM, RCR layout, Identity epilogue, sm_89.
///
/// `ldd_elements` is in **i32 element count**, not bytes — the D
/// output is a plain `int32_t[M, N]` matrix with no packing. A/B
/// `ld` values are in bytes (= packed-bit storage slots).
///
/// Returns the crate-wide `i32` status code (`0` = success; see the
/// crate-level "Status codes" list).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. Operand bytes are interpreted as packed-bit
/// B1 storage (LSB = lowest K index within each byte); no host-side
/// validation is performed. The `d` buffer must hold at least
/// `M * ldd_elements` `int32_t` elements. A non-null `workspace` must
/// cover `workspace_bytes` writable bytes (crate-level contract).
pub fn baracuda_kernels_gemm_bin_rcr_sm89_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
d: *mut c_void, ldd_elements: i64,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Workspace-size query for the B1 RCR GEMM, keyed on the problem size
/// (`m`, `n`, `k`).
///
/// NOTE(review): presumably the byte count to provide through the
/// `workspace` / `workspace_bytes` arguments of
/// `baracuda_kernels_gemm_bin_rcr_sm89_run` — confirm the unit against
/// the kernel source.
///
/// # Safety
/// Takes no pointer arguments; calling it is `unsafe` because it is
/// declared inside this `unsafe extern` block.
pub fn baracuda_kernels_gemm_bin_rcr_sm89_workspace_size(
m: i32, n: i32, k: i32,
) -> usize;
/// Implementability probe for the B1 RCR GEMM.
///
/// Takes the problem size and the `a` / `b` / `d` pointers with their
/// leading dimensions (`lda_bytes` / `ldb_bytes` in bytes,
/// `ldd_elements` in i32 elements); no workspace or stream. Returns
/// the crate-wide `i32` status code.
///
/// NOTE(review): presumably validates the problem without launching —
/// confirm against the kernel source.
///
/// # Safety
/// The crate-level pointer-validity contract applies to `a`, `b` and
/// `d`.
pub fn baracuda_kernels_gemm_bin_rcr_sm89_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
d: *const c_void, ldd_elements: i64,
) -> i32;
/// Binary (B1) GEMM, RRR layout, Identity epilogue, sm_89.
///
/// Distinct from the RCR variant in that `B` is row-major and
/// bit-packed along N in gmem (the kernel re-packs into K-bit-
/// packed smem via a bit-gather load). Same int32 output
/// convention as RCR — `ldd_elements` is in i32 element count.
///
/// Requires both `K` and `N` to be divisible by 8.
///
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the RCR entry point.
pub fn baracuda_kernels_gemm_bin_rrr_sm89_run(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
d: *mut c_void, ldd_elements: i64,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Workspace-size query for the B1 RRR GEMM, keyed on the problem size
/// (`m`, `n`, `k`).
///
/// NOTE(review): presumably the byte count to provide through the
/// `workspace` / `workspace_bytes` arguments of
/// `baracuda_kernels_gemm_bin_rrr_sm89_run` — confirm the unit against
/// the kernel source.
///
/// # Safety
/// Takes no pointer arguments; calling it is `unsafe` because it is
/// declared inside this `unsafe extern` block.
pub fn baracuda_kernels_gemm_bin_rrr_sm89_workspace_size(
m: i32, n: i32, k: i32,
) -> usize;
/// Implementability probe for the B1 RRR GEMM.
///
/// Takes the problem size and the `a` / `b` / `d` pointers with their
/// leading dimensions (`lda_bytes` / `ldb_bytes` in bytes,
/// `ldd_elements` in i32 elements); no workspace or stream. Returns
/// the crate-wide `i32` status code.
///
/// NOTE(review): presumably validates the problem without launching —
/// confirm against the kernel source.
///
/// # Safety
/// The crate-level pointer-validity contract applies to `a`, `b` and
/// `d`.
pub fn baracuda_kernels_gemm_bin_rrr_sm89_can_implement(
m: i32, n: i32, k: i32,
a: *const c_void, lda_bytes: i64,
b: *const c_void, ldb_bytes: i64,
d: *const c_void, ldd_elements: i64,
) -> i32;
}
// ============================================================================
// Elementwise — Phase 3 trailblazer (binary add, FP family)
// ============================================================================
//
// Contiguous pointwise binary kernels. Inputs / output are arbitrary-
// rank tensors flattened to a single `numel` element count on the FFI
// boundary (the Rust plan layer collapses contiguous shapes for the
// "all-contig 1D sweep" fast path).
//
// ABI:
// numel — i64 element count (product of shape).
// a / b — input device pointers (T const*).
// y — output device pointer (T*). Aliasing with
// either input is fine for the all-contig
// fast path (the kernel reads each i once
// before writing each i once).
// workspace / bytes — unused for elementwise; pass null + 0 from
// Rust. Carried in the signature for ABI
// parity with the GEMM family.
// stream — cudaStream_t cast to `*mut c_void`.
//
// Status codes are shared with the GEMM family (see crate-level doc).
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Binary elementwise `add`, f32 dtype, contiguous fast path. This
/// is the binary-pointwise trailblazer — its safety contract carries
/// over to every other binary contig launcher (`add`, `sub`, `mul`,
/// `div`, `min`, `max`, `pow`, comparison ops, etc.) across all
/// dtypes.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0); they are kept for ABI parity with the GEMM launchers. Returns
/// the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. `a`, `b`, and `y` must each point to at
/// least `numel` `float`s of device memory.
///
/// **Aliasing**: aliasing `y` with either input (`a == y` or
/// `b == y`, or both) is safe. The contig kernel evaluates
/// `y[i] = op(a[i], b[i])` with each thread touching only its own
/// index `i` (read `a[i]` + `b[i]` before write `y[i]`), so callers
/// implementing in-place binary ops (e.g. Fuel's `Op::AddInplace`,
/// `Op::MulInplace`) can dispatch the forward symbol with
/// `a_ptr == y_ptr` or `b_ptr == y_ptr` without a dedicated
/// `_inplace_` variant. The `__restrict__` qualifiers on the kernel
/// signature are an optimizer hint; the per-thread access pattern
/// makes aliasing structurally safe regardless. This contract is
/// stable across baracuda versions.
pub fn baracuda_kernels_binary_add_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_add_f32`. Validates
/// the problem size without launching a kernel. Returns the standard
/// status code mapping (see the crate-level "Status codes" list).
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn,
/// but no device dereferences occur — only host-side checks.
pub fn baracuda_kernels_binary_add_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `sub`, f32 dtype, contiguous fast path.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. `a`, `b`, and `y` must each point to at
/// least `numel` `float`s of device memory.
///
/// Aliasing `y` with `a` and/or `b` is permitted: the
/// `baracuda_kernels_binary_add_f32_run` contract carries over to
/// every binary contig launcher.
pub fn baracuda_kernels_binary_sub_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_sub_f32`. Validates
/// the problem size without launching a kernel. Returns the standard
/// status code mapping (see the crate-level "Status codes" list).
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn,
/// but no device dereferences occur — only host-side checks.
pub fn baracuda_kernels_binary_sub_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `mul`, f32 dtype, contiguous fast path.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. `a`, `b`, and `y` must each point to at
/// least `numel` `float`s of device memory.
///
/// Aliasing `y` with `a` and/or `b` is permitted: the
/// `baracuda_kernels_binary_add_f32_run` contract carries over to
/// every binary contig launcher.
pub fn baracuda_kernels_binary_mul_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_mul_f32`. Validates
/// the problem size without launching a kernel. Returns the standard
/// status code mapping (see the crate-level "Status codes" list).
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn,
/// but no device dereferences occur — only host-side checks.
pub fn baracuda_kernels_binary_mul_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `div`, f32 dtype, contiguous fast path.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. `a`, `b`, and `y` must each point to at
/// least `numel` `float`s of device memory.
///
/// Aliasing `y` with `a` and/or `b` is permitted: the
/// `baracuda_kernels_binary_add_f32_run` contract carries over to
/// every binary contig launcher.
pub fn baracuda_kernels_binary_div_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_div_f32`. Validates
/// the problem size without launching a kernel. Returns the standard
/// status code mapping (see the crate-level "Status codes" list).
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn,
/// but no device dereferences occur — only host-side checks.
pub fn baracuda_kernels_binary_div_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ------------------------------------------------------------------
// dtype fanout — f16 / bf16 / f64 siblings of the four f32 launchers
// above. ABI is identical to the f32 variants (the dtype is encoded
// only in the symbol name); see those decls for ABI contract docs.
// ------------------------------------------------------------------
/// Binary elementwise `add`, f16 dtype, contiguous fast path.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. `a`, `b`, and `y` must each point to at
/// least `numel` `__half`s of device memory.
///
/// Aliasing `y` with `a` and/or `b` is permitted: the
/// `baracuda_kernels_binary_add_f32_run` contract carries over to
/// every binary contig launcher.
pub fn baracuda_kernels_binary_add_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_add_f16`. Returns the
/// crate-wide `i32` status code.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn,
/// but no device dereferences occur — only host-side checks.
pub fn baracuda_kernels_binary_add_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `add`, bf16 dtype, contiguous fast path.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. `a`, `b`, and `y` must each point to at
/// least `numel` `__nv_bfloat16`s of device memory.
///
/// Aliasing `y` with `a` and/or `b` is permitted: the
/// `baracuda_kernels_binary_add_f32_run` contract carries over to
/// every binary contig launcher.
pub fn baracuda_kernels_binary_add_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_add_bf16`. Returns the
/// crate-wide `i32` status code.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn,
/// but no device dereferences occur — only host-side checks.
pub fn baracuda_kernels_binary_add_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `add`, f64 dtype, contiguous fast path.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. `a`, `b`, and `y` must each point to at
/// least `numel` `double`s of device memory.
///
/// Aliasing `y` with `a` and/or `b` is permitted: the
/// `baracuda_kernels_binary_add_f32_run` contract carries over to
/// every binary contig launcher.
pub fn baracuda_kernels_binary_add_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_add_f64`. Returns the
/// crate-wide `i32` status code.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn,
/// but no device dereferences occur — only host-side checks.
pub fn baracuda_kernels_binary_add_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `sub`, f16 dtype, contiguous fast path.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. `a`, `b`, and `y` must each point to at
/// least `numel` `__half`s of device memory.
///
/// Aliasing `y` with `a` and/or `b` is permitted: the
/// `baracuda_kernels_binary_add_f32_run` contract carries over to
/// every binary contig launcher.
pub fn baracuda_kernels_binary_sub_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_sub_f16`. Returns the
/// crate-wide `i32` status code.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn,
/// but no device dereferences occur — only host-side checks.
pub fn baracuda_kernels_binary_sub_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `sub`, bf16 dtype, contiguous fast path.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. `a`, `b`, and `y` must each point to at
/// least `numel` `__nv_bfloat16`s of device memory.
///
/// Aliasing `y` with `a` and/or `b` is permitted: the
/// `baracuda_kernels_binary_add_f32_run` contract carries over to
/// every binary contig launcher.
pub fn baracuda_kernels_binary_sub_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_sub_bf16`. Returns the
/// crate-wide `i32` status code.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn,
/// but no device dereferences occur — only host-side checks.
pub fn baracuda_kernels_binary_sub_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `sub`, f64 dtype, contiguous fast path.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. `a`, `b`, and `y` must each point to at
/// least `numel` `double`s of device memory.
///
/// Aliasing `y` with `a` and/or `b` is permitted: the
/// `baracuda_kernels_binary_add_f32_run` contract carries over to
/// every binary contig launcher.
pub fn baracuda_kernels_binary_sub_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_sub_f64`. Returns the
/// crate-wide `i32` status code.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn,
/// but no device dereferences occur — only host-side checks.
pub fn baracuda_kernels_binary_sub_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `mul`, f16 dtype, contiguous fast path.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. `a`, `b`, and `y` must each point to at
/// least `numel` `__half`s of device memory.
///
/// Aliasing `y` with `a` and/or `b` is permitted: the
/// `baracuda_kernels_binary_add_f32_run` contract carries over to
/// every binary contig launcher.
pub fn baracuda_kernels_binary_mul_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_mul_f16`. Returns the
/// crate-wide `i32` status code.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn,
/// but no device dereferences occur — only host-side checks.
pub fn baracuda_kernels_binary_mul_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `mul`, bf16 dtype, contiguous fast path.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. `a`, `b`, and `y` must each point to at
/// least `numel` `__nv_bfloat16`s of device memory.
///
/// Aliasing `y` with `a` and/or `b` is permitted: the
/// `baracuda_kernels_binary_add_f32_run` contract carries over to
/// every binary contig launcher.
pub fn baracuda_kernels_binary_mul_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_mul_bf16`. Returns the
/// crate-wide `i32` status code.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn,
/// but no device dereferences occur — only host-side checks.
pub fn baracuda_kernels_binary_mul_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `mul`, f64 dtype, contiguous fast path.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. `a`, `b`, and `y` must each point to at
/// least `numel` `double`s of device memory.
///
/// Aliasing `y` with `a` and/or `b` is permitted: the
/// `baracuda_kernels_binary_add_f32_run` contract carries over to
/// every binary contig launcher.
pub fn baracuda_kernels_binary_mul_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_mul_f64`. Returns the
/// crate-wide `i32` status code.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn,
/// but no device dereferences occur — only host-side checks.
pub fn baracuda_kernels_binary_mul_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `div`, f16 dtype, contiguous fast path.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. `a`, `b`, and `y` must each point to at
/// least `numel` `__half`s of device memory.
///
/// Aliasing `y` with `a` and/or `b` is permitted: the
/// `baracuda_kernels_binary_add_f32_run` contract carries over to
/// every binary contig launcher.
pub fn baracuda_kernels_binary_div_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_div_f16`. Returns the
/// crate-wide `i32` status code.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn,
/// but no device dereferences occur — only host-side checks.
pub fn baracuda_kernels_binary_div_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `div`, bf16 dtype, contiguous fast path.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. `a`, `b`, and `y` must each point to at
/// least `numel` `__nv_bfloat16`s of device memory.
///
/// Aliasing `y` with `a` and/or `b` is permitted: the
/// `baracuda_kernels_binary_add_f32_run` contract carries over to
/// every binary contig launcher.
pub fn baracuda_kernels_binary_div_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_div_bf16`. Returns the
/// crate-wide `i32` status code.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn,
/// but no device dereferences occur — only host-side checks.
pub fn baracuda_kernels_binary_div_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `div`, f64 dtype, contiguous fast path.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. `a`, `b`, and `y` must each point to at
/// least `numel` `double`s of device memory.
///
/// Aliasing `y` with `a` and/or `b` is permitted: the
/// `baracuda_kernels_binary_add_f32_run` contract carries over to
/// every binary contig launcher.
pub fn baracuda_kernels_binary_div_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_div_f64`. Returns the
/// crate-wide `i32` status code.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn,
/// but no device dereferences occur — only host-side checks.
pub fn baracuda_kernels_binary_div_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ----- Binary pow (`y = a^b`), contig fast path ------------------
// f16 / bf16 are evaluated through an f32 detour inside the kernel; f32
// uses `powf`, f64 uses `pow`. All four dtypes share the same ABI as
// the other binary contig launchers (numel, a, b, y, ws, ws_bytes,
// stream).
/// Binary elementwise `pow` (`y = a^b`), f32 dtype, contiguous fast
/// path; evaluated with `powf`.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// The `baracuda_kernels_binary_add_f32_run` contract carries over:
/// all pointer args must be device-resident and remain valid for the
/// duration of the launch, `stream` must be a live CUDA stream in the
/// current context, and `a`, `b`, and `y` must each point to at least
/// `numel` `float`s of device memory. Aliasing `y` with `a` and/or `b`
/// is permitted.
pub fn baracuda_kernels_binary_pow_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_pow_f32`. Returns the
/// crate-wide `i32` status code.
///
/// NOTE(review): presumably host-side checks only, as documented for
/// the add / sub / mul / div probes — confirm.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn.
pub fn baracuda_kernels_binary_pow_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `pow` (`y = a^b`), f16 dtype, contiguous fast
/// path; evaluated through an f32 detour inside the kernel.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// The `baracuda_kernels_binary_add_f32_run` contract carries over:
/// all pointer args must be device-resident and remain valid for the
/// duration of the launch, `stream` must be a live CUDA stream in the
/// current context, and `a`, `b`, and `y` must each point to at least
/// `numel` `__half`s of device memory. Aliasing `y` with `a` and/or `b`
/// is permitted.
pub fn baracuda_kernels_binary_pow_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_pow_f16`. Returns the
/// crate-wide `i32` status code.
///
/// NOTE(review): presumably host-side checks only, as documented for
/// the add / sub / mul / div probes — confirm.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn.
pub fn baracuda_kernels_binary_pow_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `pow` (`y = a^b`), bf16 dtype, contiguous fast
/// path; evaluated through an f32 detour inside the kernel.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// The `baracuda_kernels_binary_add_f32_run` contract carries over:
/// all pointer args must be device-resident and remain valid for the
/// duration of the launch, `stream` must be a live CUDA stream in the
/// current context, and `a`, `b`, and `y` must each point to at least
/// `numel` `__nv_bfloat16`s of device memory. Aliasing `y` with `a`
/// and/or `b` is permitted.
pub fn baracuda_kernels_binary_pow_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_pow_bf16`. Returns the
/// crate-wide `i32` status code.
///
/// NOTE(review): presumably host-side checks only, as documented for
/// the add / sub / mul / div probes — confirm.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn.
pub fn baracuda_kernels_binary_pow_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `pow` (`y = a^b`), f64 dtype, contiguous fast
/// path; evaluated with `pow`.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// The `baracuda_kernels_binary_add_f32_run` contract carries over:
/// all pointer args must be device-resident and remain valid for the
/// duration of the launch, `stream` must be a live CUDA stream in the
/// current context, and `a`, `b`, and `y` must each point to at least
/// `numel` `double`s of device memory. Aliasing `y` with `a` and/or `b`
/// is permitted.
pub fn baracuda_kernels_binary_pow_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_pow_f64`. Returns the
/// crate-wide `i32` status code.
///
/// NOTE(review): presumably host-side checks only, as documented for
/// the add / sub / mul / div probes — confirm.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn.
pub fn baracuda_kernels_binary_pow_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ----- Binary atan2 (`y = atan2(a, b)`), contig fast path --------
// f16 / bf16 are evaluated through an f32 detour inside the kernel; f32
// uses `atan2f`, f64 uses `atan2`.
/// Binary elementwise `atan2` (`y = atan2(a, b)`), f32 dtype,
/// contiguous fast path; evaluated with `atan2f`.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// The `baracuda_kernels_binary_add_f32_run` contract carries over:
/// all pointer args must be device-resident and remain valid for the
/// duration of the launch, `stream` must be a live CUDA stream in the
/// current context, and `a`, `b`, and `y` must each point to at least
/// `numel` `float`s of device memory. Aliasing `y` with `a` and/or `b`
/// is permitted.
pub fn baracuda_kernels_binary_atan2_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_atan2_f32`. Returns
/// the crate-wide `i32` status code.
///
/// NOTE(review): presumably host-side checks only, as documented for
/// the add / sub / mul / div probes — confirm.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn.
pub fn baracuda_kernels_binary_atan2_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `atan2` (`y = atan2(a, b)`), f16 dtype,
/// contiguous fast path; evaluated through an f32 detour inside the
/// kernel.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// The `baracuda_kernels_binary_add_f32_run` contract carries over:
/// all pointer args must be device-resident and remain valid for the
/// duration of the launch, `stream` must be a live CUDA stream in the
/// current context, and `a`, `b`, and `y` must each point to at least
/// `numel` `__half`s of device memory. Aliasing `y` with `a` and/or `b`
/// is permitted.
pub fn baracuda_kernels_binary_atan2_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_atan2_f16`. Returns
/// the crate-wide `i32` status code.
///
/// NOTE(review): presumably host-side checks only, as documented for
/// the add / sub / mul / div probes — confirm.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn.
pub fn baracuda_kernels_binary_atan2_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `atan2` (`y = atan2(a, b)`), bf16 dtype,
/// contiguous fast path; evaluated through an f32 detour inside the
/// kernel.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// The `baracuda_kernels_binary_add_f32_run` contract carries over:
/// all pointer args must be device-resident and remain valid for the
/// duration of the launch, `stream` must be a live CUDA stream in the
/// current context, and `a`, `b`, and `y` must each point to at least
/// `numel` `__nv_bfloat16`s of device memory. Aliasing `y` with `a`
/// and/or `b` is permitted.
pub fn baracuda_kernels_binary_atan2_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_atan2_bf16`. Returns
/// the crate-wide `i32` status code.
///
/// NOTE(review): presumably host-side checks only, as documented for
/// the add / sub / mul / div probes — confirm.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn.
pub fn baracuda_kernels_binary_atan2_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `atan2` (`y = atan2(a, b)`), f64 dtype,
/// contiguous fast path; evaluated with `atan2`.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// The `baracuda_kernels_binary_add_f32_run` contract carries over:
/// all pointer args must be device-resident and remain valid for the
/// duration of the launch, `stream` must be a live CUDA stream in the
/// current context, and `a`, `b`, and `y` must each point to at least
/// `numel` `double`s of device memory. Aliasing `y` with `a` and/or `b`
/// is permitted.
pub fn baracuda_kernels_binary_atan2_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_atan2_f64`. Returns
/// the crate-wide `i32` status code.
///
/// NOTE(review): presumably host-side checks only, as documented for
/// the add / sub / mul / div probes — confirm.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn.
pub fn baracuda_kernels_binary_atan2_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ----- Binary hypot (`y = sqrt(a² + b²)`), contig fast path ------
// f16 / bf16 are evaluated through an f32 detour inside the kernel; f32
// uses `hypotf`, f64 uses `hypot`. Both libdevice intrinsics are
// overflow-/underflow-safe (internally rescale by max(|a|, |b|)).
/// Binary elementwise `hypot` (`y = sqrt(a² + b²)`), f32 dtype,
/// contiguous fast path; evaluated with `hypotf`.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// The `baracuda_kernels_binary_add_f32_run` contract carries over:
/// all pointer args must be device-resident and remain valid for the
/// duration of the launch, `stream` must be a live CUDA stream in the
/// current context, and `a`, `b`, and `y` must each point to at least
/// `numel` `float`s of device memory. Aliasing `y` with `a` and/or `b`
/// is permitted.
pub fn baracuda_kernels_binary_hypot_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_hypot_f32`. Returns
/// the crate-wide `i32` status code.
///
/// NOTE(review): presumably host-side checks only, as documented for
/// the add / sub / mul / div probes — confirm.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn.
pub fn baracuda_kernels_binary_hypot_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `hypot` (`y = sqrt(a² + b²)`), f16 dtype,
/// contiguous fast path; evaluated through an f32 detour inside the
/// kernel.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// The `baracuda_kernels_binary_add_f32_run` contract carries over:
/// all pointer args must be device-resident and remain valid for the
/// duration of the launch, `stream` must be a live CUDA stream in the
/// current context, and `a`, `b`, and `y` must each point to at least
/// `numel` `__half`s of device memory. Aliasing `y` with `a` and/or `b`
/// is permitted.
pub fn baracuda_kernels_binary_hypot_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_hypot_f16`. Returns
/// the crate-wide `i32` status code.
///
/// NOTE(review): presumably host-side checks only, as documented for
/// the add / sub / mul / div probes — confirm.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn.
pub fn baracuda_kernels_binary_hypot_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `hypot` (`y = sqrt(a² + b²)`), bf16 dtype,
/// contiguous fast path; evaluated through an f32 detour inside the
/// kernel.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// The `baracuda_kernels_binary_add_f32_run` contract carries over:
/// all pointer args must be device-resident and remain valid for the
/// duration of the launch, `stream` must be a live CUDA stream in the
/// current context, and `a`, `b`, and `y` must each point to at least
/// `numel` `__nv_bfloat16`s of device memory. Aliasing `y` with `a`
/// and/or `b` is permitted.
pub fn baracuda_kernels_binary_hypot_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_hypot_bf16`. Returns
/// the crate-wide `i32` status code.
///
/// NOTE(review): presumably host-side checks only, as documented for
/// the add / sub / mul / div probes — confirm.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn.
pub fn baracuda_kernels_binary_hypot_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `hypot` (`y = sqrt(a² + b²)`), f64 dtype,
/// contiguous fast path; evaluated with `hypot`.
///
/// `numel` is the flattened element count. `workspace` /
/// `workspace_bytes` are unused by the elementwise family (pass null +
/// 0). Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// The `baracuda_kernels_binary_add_f32_run` contract carries over:
/// all pointer args must be device-resident and remain valid for the
/// duration of the launch, `stream` must be a live CUDA stream in the
/// current context, and `a`, `b`, and `y` must each point to at least
/// `numel` `double`s of device memory. Aliasing `y` with `a` and/or `b`
/// is permitted.
pub fn baracuda_kernels_binary_hypot_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_hypot_f64`. Returns
/// the crate-wide `i32` status code.
///
/// NOTE(review): presumably host-side checks only, as documented for
/// the add / sub / mul / div probes — confirm.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn.
pub fn baracuda_kernels_binary_hypot_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ----- Binary copysign (`y = copysign(a, b) = |a| · sign(b)`), contig
// f16 / bf16 are evaluated through an f32 detour inside the kernel; f32
// uses `copysignf`, f64 uses `copysign`. Pure sign-bit manipulation —
// well-defined for every IEEE input including NaN.
/// Binary `copysign`, f32, contig.
pub fn baracuda_kernels_binary_copysign_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `copysign`, f32, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_copysign_f32_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_copysign_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `copysign`, f16, contig.
///
/// Computes `y = copysign(a, b)` (through an f32 detour for this
/// dtype) over `numel` elements, reading `a` / `b` and writing `y`.
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_copysign_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `copysign`, f16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_copysign_f16_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_copysign_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `copysign`, bf16, contig.
///
/// Computes `y = copysign(a, b)` (through an f32 detour for this
/// dtype) over `numel` elements, reading `a` / `b` and writing `y`.
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_copysign_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `copysign`, bf16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_copysign_bf16_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_copysign_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `copysign`, f64, contig.
///
/// Computes `y = copysign(a, b)` (via `copysign` for this dtype) over
/// `numel` elements, reading `a` / `b` and writing `y`. Returns a
/// status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_copysign_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `copysign`, f64, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_copysign_f64_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_copysign_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ----- Binary nextafter (`y = nextafter(a, b)`), contig fast path -
// f32 → `nextafterf`, f64 → `nextafter`. f16 / bf16 use direct
// bit-pattern manipulation (no f32 detour — adjacent half cells
// round-trip through f32 to themselves, so a naive f32 detour
// returns `a`, not its neighbor).
/// Binary `nextafter`, f32, contig.
///
/// Computes `y = nextafter(a, b)` (via `nextafterf` for this dtype)
/// over `numel` elements, reading `a` / `b` and writing `y`. Returns a
/// status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_nextafter_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `nextafter`, f32, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_nextafter_f32_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_nextafter_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `nextafter`, f16, contig.
///
/// Computes `y = nextafter(a, b)` using direct bit-pattern
/// manipulation (no f32 detour for this dtype) over `numel` elements,
/// reading `a` / `b` and writing `y`. Returns a status code per the
/// crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_nextafter_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `nextafter`, f16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_nextafter_f16_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_nextafter_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `nextafter`, bf16, contig.
///
/// Computes `y = nextafter(a, b)` using direct bit-pattern
/// manipulation (no f32 detour for this dtype) over `numel` elements,
/// reading `a` / `b` and writing `y`. Returns a status code per the
/// crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_nextafter_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `nextafter`, bf16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_nextafter_bf16_run`; takes the same
/// `numel` and operand pointers. Returns a status code per the
/// crate-level docs (`0` = success).
pub fn baracuda_kernels_binary_nextafter_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `nextafter`, f64, contig.
///
/// Computes `y = nextafter(a, b)` (via `nextafter` for this dtype)
/// over `numel` elements, reading `a` / `b` and writing `y`. Returns a
/// status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_nextafter_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `nextafter`, f64, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_nextafter_f64_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_nextafter_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ----- Binary fmin (`y = fmin(a, b)` IEEE 754, NaN-aware), contig --
// Distinct from `BinaryKind::Minimum` which propagates NaN. f32 →
// `fminf`, f64 → `fmin`, f16 / bf16 → f32-detour.
/// Binary `fmin`, f32, contig.
///
/// Computes the NaN-aware `y = fmin(a, b)` (via `fminf` for this
/// dtype) over `numel` elements, reading `a` / `b` and writing `y`.
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_fmin_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `fmin`, f32, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_fmin_f32_run`; takes the same `numel` and
/// operand pointers. Returns a status code per the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_binary_fmin_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `fmin`, f16, contig.
///
/// Computes the NaN-aware `y = fmin(a, b)` (through an f32 detour for
/// this dtype) over `numel` elements, reading `a` / `b` and writing
/// `y`. Returns a status code per the crate-level docs (`0` =
/// success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_fmin_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `fmin`, f16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_fmin_f16_run`; takes the same `numel` and
/// operand pointers. Returns a status code per the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_binary_fmin_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `fmin`, bf16, contig.
///
/// Computes the NaN-aware `y = fmin(a, b)` (through an f32 detour for
/// this dtype) over `numel` elements, reading `a` / `b` and writing
/// `y`. Returns a status code per the crate-level docs (`0` =
/// success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_fmin_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `fmin`, bf16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_fmin_bf16_run`; takes the same `numel` and
/// operand pointers. Returns a status code per the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_binary_fmin_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `fmin`, f64, contig.
///
/// Computes the NaN-aware `y = fmin(a, b)` (via `fmin` for this
/// dtype) over `numel` elements, reading `a` / `b` and writing `y`.
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_fmin_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `fmin`, f64, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_fmin_f64_run`; takes the same `numel` and
/// operand pointers. Returns a status code per the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_binary_fmin_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ----- Binary fmax (`y = fmax(a, b)` IEEE 754, NaN-aware), contig --
// Distinct from `BinaryKind::Maximum` which propagates NaN. f32 →
// `fmaxf`, f64 → `fmax`, f16 / bf16 → f32-detour.
/// Binary `fmax`, f32, contig.
///
/// Computes the NaN-aware `y = fmax(a, b)` (via `fmaxf` for this
/// dtype) over `numel` elements, reading `a` / `b` and writing `y`.
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_fmax_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `fmax`, f32, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_fmax_f32_run`; takes the same `numel` and
/// operand pointers. Returns a status code per the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_binary_fmax_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `fmax`, f16, contig.
///
/// Computes the NaN-aware `y = fmax(a, b)` (through an f32 detour for
/// this dtype) over `numel` elements, reading `a` / `b` and writing
/// `y`. Returns a status code per the crate-level docs (`0` =
/// success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_fmax_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `fmax`, f16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_fmax_f16_run`; takes the same `numel` and
/// operand pointers. Returns a status code per the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_binary_fmax_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `fmax`, bf16, contig.
///
/// Computes the NaN-aware `y = fmax(a, b)` (through an f32 detour for
/// this dtype) over `numel` elements, reading `a` / `b` and writing
/// `y`. Returns a status code per the crate-level docs (`0` =
/// success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_fmax_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `fmax`, bf16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_fmax_bf16_run`; takes the same `numel` and
/// operand pointers. Returns a status code per the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_binary_fmax_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `fmax`, f64, contig.
///
/// Computes the NaN-aware `y = fmax(a, b)` (via `fmax` for this
/// dtype) over `numel` elements, reading `a` / `b` and writing `y`.
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_fmax_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `fmax`, f64, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_fmax_f64_run`; takes the same `numel` and
/// operand pointers. Returns a status code per the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_binary_fmax_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ----- Binary maximum (`y = max(a, b)` NaN-PROPAGATING), contig ----
// Distinct from `BinaryKind::Fmax` which is NaN-aware (NaN-ignored).
// Any NaN input produces a NaN output, matching `torch.maximum`.
// f32 / f64 → compare-and-select with explicit NaN guards;
// f16 / bf16 → f32-detour with same NaN guard.
/// Binary `maximum`, f32, contig.
///
/// Computes the NaN-propagating `y = max(a, b)` (compare-and-select
/// with explicit NaN guards for this dtype) over `numel` elements,
/// reading `a` / `b` and writing `y`. Returns a status code per the
/// crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_maximum_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `maximum`, f32, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_maximum_f32_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_maximum_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `maximum`, f16, contig.
///
/// Computes the NaN-propagating `y = max(a, b)` (f32 detour with a
/// NaN guard for this dtype) over `numel` elements, reading `a` / `b`
/// and writing `y`. Returns a status code per the crate-level docs
/// (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_maximum_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `maximum`, f16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_maximum_f16_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_maximum_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `maximum`, bf16, contig.
///
/// Computes the NaN-propagating `y = max(a, b)` (f32 detour with a
/// NaN guard for this dtype) over `numel` elements, reading `a` / `b`
/// and writing `y`. Returns a status code per the crate-level docs
/// (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_maximum_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `maximum`, bf16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_maximum_bf16_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_maximum_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `maximum`, f64, contig.
///
/// Computes the NaN-propagating `y = max(a, b)` (compare-and-select
/// with explicit NaN guards for this dtype) over `numel` elements,
/// reading `a` / `b` and writing `y`. Returns a status code per the
/// crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_maximum_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `maximum`, f64, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_maximum_f64_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_maximum_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ----- Binary minimum (`y = min(a, b)` NaN-PROPAGATING), contig ----
// Distinct from `BinaryKind::Fmin` which is NaN-aware (NaN-ignored).
// Any NaN input produces a NaN output, matching `torch.minimum`.
/// Binary `minimum`, f32, contig.
///
/// Computes the NaN-propagating `y = min(a, b)` (any NaN input yields
/// NaN) over `numel` elements, reading `a` / `b` and writing `y`.
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_minimum_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `minimum`, f32, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_minimum_f32_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_minimum_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `minimum`, f16, contig.
///
/// Computes the NaN-propagating `y = min(a, b)` (any NaN input yields
/// NaN) over `numel` elements, reading `a` / `b` and writing `y`.
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_minimum_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `minimum`, f16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_minimum_f16_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_minimum_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `minimum`, bf16, contig.
///
/// Computes the NaN-propagating `y = min(a, b)` (any NaN input yields
/// NaN) over `numel` elements, reading `a` / `b` and writing `y`.
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_minimum_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `minimum`, bf16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_minimum_bf16_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_minimum_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `minimum`, f64, contig.
///
/// Computes the NaN-propagating `y = min(a, b)` (any NaN input yields
/// NaN) over `numel` elements, reading `a` / `b` and writing `y`.
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_minimum_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `minimum`, f64, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_minimum_f64_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_minimum_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ----- Binary floor_divide (`y = floor(a / b)`), contig ------------
/// Binary `floor_divide`, f32, contig.
///
/// Computes `y = floor(a / b)` over `numel` elements, reading
/// `a` / `b` and writing `y`. Returns a status code per the
/// crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_floor_divide_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `floor_divide`, f32, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_floor_divide_f32_run`; takes the same
/// `numel` and operand pointers. Returns a status code per the
/// crate-level docs (`0` = success).
pub fn baracuda_kernels_binary_floor_divide_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `floor_divide`, f16, contig.
///
/// Computes `y = floor(a / b)` over `numel` elements, reading
/// `a` / `b` and writing `y`. Returns a status code per the
/// crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_floor_divide_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `floor_divide`, f16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_floor_divide_f16_run`; takes the same
/// `numel` and operand pointers. Returns a status code per the
/// crate-level docs (`0` = success).
pub fn baracuda_kernels_binary_floor_divide_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `floor_divide`, bf16, contig.
///
/// Computes `y = floor(a / b)` over `numel` elements, reading
/// `a` / `b` and writing `y`. Returns a status code per the
/// crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_floor_divide_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `floor_divide`, bf16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_floor_divide_bf16_run`; takes the same
/// `numel` and operand pointers. Returns a status code per the
/// crate-level docs (`0` = success).
pub fn baracuda_kernels_binary_floor_divide_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `floor_divide`, f64, contig.
///
/// Computes `y = floor(a / b)` over `numel` elements, reading
/// `a` / `b` and writing `y`. Returns a status code per the
/// crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_floor_divide_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `floor_divide`, f64, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_floor_divide_f64_run`; takes the same
/// `numel` and operand pointers. Returns a status code per the
/// crate-level docs (`0` = success).
pub fn baracuda_kernels_binary_floor_divide_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ----- Binary mod (`y = a - floor(a/b)*b`, sign of b), contig -------
// Python-style modulo. Distinct from `BinaryKind::Remainder`, which
// is C-style (sign of a).
/// Binary `mod`, f32, contig.
///
/// Computes the Python-style `y = a - floor(a/b)*b` (result takes the
/// sign of `b`) over `numel` elements, reading `a` / `b` and writing
/// `y`. Returns a status code per the crate-level docs (`0` =
/// success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_mod_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `mod`, f32, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_mod_f32_run`; takes the same `numel` and
/// operand pointers. Returns a status code per the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_binary_mod_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `mod`, f16, contig.
///
/// Computes the Python-style `y = a - floor(a/b)*b` (result takes the
/// sign of `b`) over `numel` elements, reading `a` / `b` and writing
/// `y`. Returns a status code per the crate-level docs (`0` =
/// success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_mod_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `mod`, f16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_mod_f16_run`; takes the same `numel` and
/// operand pointers. Returns a status code per the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_binary_mod_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `mod`, bf16, contig.
///
/// Computes the Python-style `y = a - floor(a/b)*b` (result takes the
/// sign of `b`) over `numel` elements, reading `a` / `b` and writing
/// `y`. Returns a status code per the crate-level docs (`0` =
/// success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_mod_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `mod`, bf16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_mod_bf16_run`; takes the same `numel` and
/// operand pointers. Returns a status code per the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_binary_mod_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `mod`, f64, contig.
///
/// Computes the Python-style `y = a - floor(a/b)*b` (result takes the
/// sign of `b`) over `numel` elements, reading `a` / `b` and writing
/// `y`. Returns a status code per the crate-level docs (`0` =
/// success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_mod_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `mod`, f64, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_mod_f64_run`; takes the same `numel` and
/// operand pointers. Returns a status code per the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_binary_mod_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ----- Binary remainder (`y = fmod(a, b)`, sign of a), contig ------
// C-style remainder via libdevice `fmodf` / `fmod`. Distinct from
// `BinaryKind::Mod`, which is Python-style (sign of b).
/// Binary `remainder`, f32, contig.
///
/// Computes the C-style `y = fmod(a, b)` (result takes the sign of
/// `a`) over `numel` elements, reading `a` / `b` and writing `y`.
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_remainder_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `remainder`, f32, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_remainder_f32_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_remainder_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `remainder`, f16, contig.
///
/// Computes the C-style `y = fmod(a, b)` (result takes the sign of
/// `a`) over `numel` elements, reading `a` / `b` and writing `y`.
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_remainder_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `remainder`, f16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_remainder_f16_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_remainder_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `remainder`, bf16, contig.
///
/// Computes the C-style `y = fmod(a, b)` (result takes the sign of
/// `a`) over `numel` elements, reading `a` / `b` and writing `y`.
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_remainder_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `remainder`, bf16, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_remainder_bf16_run`; takes the same
/// `numel` and operand pointers. Returns a status code per the
/// crate-level docs (`0` = success).
pub fn baracuda_kernels_binary_remainder_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `remainder`, f64, contig.
///
/// Computes the C-style `y = fmod(a, b)` (result takes the sign of
/// `a`) over `numel` elements, reading `a` / `b` and writing `y`.
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Crate-level contract applies: valid device pointers (no bounds
/// checks), `workspace` covering `workspace_bytes` when non-null, and
/// a valid CUDA `stream` for the current context.
pub fn baracuda_kernels_binary_remainder_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary `remainder`, f64, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_remainder_f64_run`; takes the same `numel`
/// and operand pointers. Returns a status code per the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_binary_remainder_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
}
// ============================================================================
// Elementwise — integer / bool binary ops (contig only)
// ============================================================================
//
// Phase 3.3 integer + bool fanout. Five bitwise ops (`and` / `or` /
// `xor` / `left_shift` / `right_shift`) over `{i32, i64}` plus three
// logical ops (`and` / `or` / `xor`) over `Bool` (1-byte storage).
//
// **Contig only.** Strided / broadcast variants are deferred to a
// later milestone — the caller is expected to materialize a contiguous
// operand if it needs broadcast semantics for these op families. The
// Rust dispatcher therefore routes any non-contig launch through
// `Error::Unsupported` for these (kind, dtype) cells.
//
// Right-shift on signed integers is **arithmetic** (sign-extending),
// matching PyTorch's contract — see
// `kernels/elementwise/binary_bitwise_right_shift_int.cu` for the
// portability rationale.
//
// Logical ops normalize their inputs to canonical 0 / 1 before
// applying the boolean op so the output is always strictly 0 or 1
// even when the inputs are unnormalized byte storage.
//
// ABI mirrors the FP contig binary launchers (numel + 3 pointers +
// workspace + stream). Status codes are shared with the GEMM family
// (see crate-level doc).
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Integer / bool contig binary launchers. Every `*_run` takes
// (numel, a, b, y, workspace, workspace_bytes, stream) and every
// `*_can_implement` takes (numel, a, b, y); both return an `i32`
// status code as listed in the crate-level docs (`0` = success).
// -------- bitwise_and --------
/// Binary bitwise `and`, i32 dtype, contig.
///
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for
/// the duration of the launch. `stream` must be a live CUDA stream
/// in the current context. `a`, `b`, and `y` must each point to at
/// least `numel` `int32_t`s of device memory.
pub fn baracuda_kernels_binary_bitwise_and_i32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary bitwise `and`, i32 dtype, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_bitwise_and_i32_run`; takes the same
/// `numel` and operand pointers and returns a crate-level status code.
pub fn baracuda_kernels_binary_bitwise_and_i32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary bitwise `and`, i64 dtype, contig.
///
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Same contract as
/// `baracuda_kernels_binary_bitwise_and_i32_run`, but each tensor
/// covers at least `numel` `int64_t`s.
pub fn baracuda_kernels_binary_bitwise_and_i64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary bitwise `and`, i64 dtype, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_bitwise_and_i64_run`; takes the same
/// `numel` and operand pointers and returns a crate-level status code.
pub fn baracuda_kernels_binary_bitwise_and_i64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// -------- bitwise_or --------
/// Binary bitwise `or`, i32 dtype, contig.
///
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_bitwise_and_i32_run`.
pub fn baracuda_kernels_binary_bitwise_or_i32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary bitwise `or`, i32 dtype, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_bitwise_or_i32_run`; takes the same
/// `numel` and operand pointers and returns a crate-level status code.
pub fn baracuda_kernels_binary_bitwise_or_i32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary bitwise `or`, i64 dtype, contig.
///
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_bitwise_and_i64_run`.
pub fn baracuda_kernels_binary_bitwise_or_i64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary bitwise `or`, i64 dtype, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_bitwise_or_i64_run`; takes the same
/// `numel` and operand pointers and returns a crate-level status code.
pub fn baracuda_kernels_binary_bitwise_or_i64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// -------- bitwise_xor --------
/// Binary bitwise `xor`, i32 dtype, contig.
///
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_bitwise_and_i32_run`.
pub fn baracuda_kernels_binary_bitwise_xor_i32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary bitwise `xor`, i32 dtype, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_bitwise_xor_i32_run`; takes the same
/// `numel` and operand pointers and returns a crate-level status code.
pub fn baracuda_kernels_binary_bitwise_xor_i32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary bitwise `xor`, i64 dtype, contig.
///
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_bitwise_and_i64_run`.
pub fn baracuda_kernels_binary_bitwise_xor_i64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary bitwise `xor`, i64 dtype, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_bitwise_xor_i64_run`; takes the same
/// `numel` and operand pointers and returns a crate-level status code.
pub fn baracuda_kernels_binary_bitwise_xor_i64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// -------- bitwise_left_shift --------
/// Binary bitwise `left_shift`, i32 dtype, contig.
///
/// `y = a << b`. Out-of-range shift amounts inherit the host
/// architecture's behavior — callers requiring defined behavior
/// should clamp `b` to `[0, 31]` themselves.
///
/// NOTE(review): "host architecture" presumably means the
/// architecture the kernel executes on (the target GPU) rather than
/// the host CPU — confirm against the kernel source.
///
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_bitwise_and_i32_run`.
pub fn baracuda_kernels_binary_bitwise_left_shift_i32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary bitwise `left_shift`, i32 dtype, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_bitwise_left_shift_i32_run`; takes the
/// same `numel` and operand pointers and returns a crate-level status
/// code.
pub fn baracuda_kernels_binary_bitwise_left_shift_i32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary bitwise `left_shift`, i64 dtype, contig.
///
/// `y = a << b`. Out-of-range shift amounts inherit the host
/// architecture's behavior — callers requiring defined behavior
/// should clamp `b` to `[0, 63]` themselves.
///
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_bitwise_and_i64_run`.
pub fn baracuda_kernels_binary_bitwise_left_shift_i64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary bitwise `left_shift`, i64 dtype, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_bitwise_left_shift_i64_run`; takes the
/// same `numel` and operand pointers and returns a crate-level status
/// code.
pub fn baracuda_kernels_binary_bitwise_left_shift_i64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// -------- bitwise_right_shift --------
/// Binary bitwise `right_shift`, i32 dtype, contig. **Arithmetic**
/// shift (sign-extending), matching PyTorch.
///
/// `y = a >> b`. Out-of-range shift amounts inherit the host
/// architecture's behavior.
///
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_bitwise_and_i32_run`.
pub fn baracuda_kernels_binary_bitwise_right_shift_i32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary bitwise `right_shift`, i32 dtype, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_bitwise_right_shift_i32_run`; takes the
/// same `numel` and operand pointers and returns a crate-level status
/// code.
pub fn baracuda_kernels_binary_bitwise_right_shift_i32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary bitwise `right_shift`, i64 dtype, contig. **Arithmetic**
/// shift (sign-extending), matching PyTorch.
///
/// NOTE(review): the i32 variant documents that out-of-range shift
/// amounts are not normalized; presumably the same caveat applies
/// here — confirm against the kernel source.
///
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_bitwise_and_i64_run`.
pub fn baracuda_kernels_binary_bitwise_right_shift_i64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary bitwise `right_shift`, i64 dtype, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_bitwise_right_shift_i64_run`; takes the
/// same `numel` and operand pointers and returns a crate-level status
/// code.
pub fn baracuda_kernels_binary_bitwise_right_shift_i64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// -------- logical_and --------
/// Binary logical `and`, Bool dtype (1-byte storage), contig.
///
/// Truthiness convention: 0 = false, any non-zero = true. The
/// kernel normalizes each input before applying `&&`, so the output
/// is always strictly 0 or 1.
///
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident and remain valid for
/// the duration of the launch. `stream` must be a live CUDA stream
/// in the current context. `a`, `b`, and `y` must each point to at
/// least `numel` bytes of device memory.
pub fn baracuda_kernels_binary_logical_and_bool_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary logical `and`, Bool dtype, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_logical_and_bool_run`; takes the same
/// `numel` and operand pointers and returns a crate-level status code.
pub fn baracuda_kernels_binary_logical_and_bool_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// -------- logical_or --------
/// Binary logical `or`, Bool dtype, contig.
///
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_logical_and_bool_run`.
pub fn baracuda_kernels_binary_logical_or_bool_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary logical `or`, Bool dtype, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_logical_or_bool_run`; takes the same
/// `numel` and operand pointers and returns a crate-level status code.
pub fn baracuda_kernels_binary_logical_or_bool_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// -------- logical_xor --------
/// Binary logical `xor`, Bool dtype, contig.
///
/// Returns a status code per the crate-level docs (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_logical_and_bool_run`.
pub fn baracuda_kernels_binary_logical_xor_bool_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Binary logical `xor`, Bool dtype, can-implement.
///
/// Implementability check paired with
/// `baracuda_kernels_binary_logical_xor_bool_run`; takes the same
/// `numel` and operand pointers and returns a crate-level status code.
pub fn baracuda_kernels_binary_logical_xor_bool_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
}
// ============================================================================
// Elementwise — strided / broadcast variants
// ============================================================================
//
// Companion launchers to the contig fast path above. The Rust dispatcher
// picks contig vs strided at launch time based on whether all three
// operands are fully contiguous. The strided kernel handles every
// non-contig case: broadcast (stride 0 along an axis), transposed
// views, arbitrary strided slices.
//
// ABI:
// numel — i64 element count of the OUTPUT tensor
// (product of `shape`).
// rank — i32, number of valid axes in [0, 8].
// shape — points to `[i32; rank]` on the host stack.
// The OUTPUT shape; operands `a` and `b` are
// read at the same coords via their own
// strides (broadcast = stride 0).
// stride_a / b / y — points to `[i64; rank]` on the host stack,
// the per-axis element stride for each tensor.
// A stride of 0 along axis d marks a broadcast
// operand. Output stride is typically
// contiguous but the kernel accepts arbitrary
// strides.
// a / b — input device pointers (T const*).
// y — output device pointer (T*). Aliasing is
// safe in the contig case (i ≤ N); in the
// strided / broadcast case it's caller-
// responsibility — the kernel reads each
// (off_a, off_b) once before writing each
// off_y once, but stride-0 broadcast means
// many writes to the same off_y if the output
// is also broadcast, which is undefined.
// workspace / bytes — unused; pass null + 0 from Rust.
// stream — cudaStream_t cast to `*mut c_void`.
//
// Status codes mirror the GEMM family (see crate-level doc).
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Binary elementwise `add`, f32 dtype, strided / broadcast path.
/// This is the binary-strided trailblazer — its safety contract
/// (including aliasing) carries over to every other binary strided
/// launcher across all dtypes.
///
/// `numel` is the output element count, `rank` the number of valid
/// axes, and `shape` / `stride_*` are host-side arrays of at least
/// `rank` entries (a stride of 0 along an axis marks a broadcast
/// operand). `workspace` / `workspace_bytes` are unused (pass null
/// and 0). Returns a status code per the crate-level docs (`0` =
/// success).
///
/// # Safety
/// All device pointer args must be device-resident and remain valid
/// for the duration of the launch. `shape` / `stride_*` are
/// host-side pointers to arrays of at least `rank` elements that
/// remain valid for the duration of the host-side launch call
/// (the launcher copies them into the kernel parameter block
/// before returning — they may be freed after the host call
/// completes, before the kernel completes on device).
///
/// **Aliasing (Phase 62)**: aliasing `y` with either input
/// (`a == y` or `b == y`) is safe IF AND ONLY IF the aliased
/// input's stride array equals `stride_y` element-for-element
/// (use [`baracuda_kernels_types::strides_equal`] to check). With
/// equal strides, each thread reads its own `off_y` cell from the
/// aliased input then writes the same cell, identical structure
/// to the contig binary case. With unequal strides, different
/// threads can read cells that other threads have already
/// overwritten — silent data corruption. The kernel does no
/// validation; this is the caller's contract. The `__restrict__`
/// qualifiers on the kernel signature are an optimizer hint
/// (additional reordering freedom) — they are safe to violate
/// only when the per-thread access pattern remains read-then-write
/// at the same cell, i.e., when strides are equal.
///
/// Additional preconditions (apply with or without aliasing): no
/// zero strides on `y`, and `(shape, stride_y)` must specify a
/// valid permutation (no two linear `i` values mapping to the
/// same `off_y` cell).
///
/// This contract is stable across baracuda versions.
pub fn baracuda_kernels_binary_add_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_add_f32_strided`.
///
/// Takes the same shape / stride / operand arguments as
/// `baracuda_kernels_binary_add_f32_strided_run` (no workspace or
/// stream) and returns a status code per the crate-level docs
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
/// NOTE(review): `shape` / `stride_*` presumably still need to point
/// to at least `rank` readable host elements — confirm against the
/// launcher source.
pub fn baracuda_kernels_binary_add_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `add`, f16 dtype, strided / broadcast path.
///
/// `numel` is the output element count, `rank` the number of valid
/// axes, and `shape` / `stride_*` are host-side arrays of at least
/// `rank` entries (stride 0 marks a broadcast axis). `workspace` /
/// `workspace_bytes` are unused (pass null and 0). Returns a status
/// code per the crate-level docs (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_add_f32_strided_run`.
pub fn baracuda_kernels_binary_add_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_add_f16_strided`.
///
/// Takes the same shape / stride / operand arguments as
/// `baracuda_kernels_binary_add_f16_strided_run` (no workspace or
/// stream) and returns a status code per the crate-level docs
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
/// NOTE(review): `shape` / `stride_*` presumably still need to point
/// to at least `rank` readable host elements — confirm against the
/// launcher source.
pub fn baracuda_kernels_binary_add_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `add`, bf16 dtype, strided / broadcast path.
///
/// `numel` is the output element count, `rank` the number of valid
/// axes, and `shape` / `stride_*` are host-side arrays of at least
/// `rank` entries (stride 0 marks a broadcast axis). `workspace` /
/// `workspace_bytes` are unused (pass null and 0). Returns a status
/// code per the crate-level docs (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_add_f32_strided_run`.
pub fn baracuda_kernels_binary_add_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_add_bf16_strided`.
///
/// Takes the same shape / stride / operand arguments as
/// `baracuda_kernels_binary_add_bf16_strided_run` (no workspace or
/// stream) and returns a status code per the crate-level docs
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
/// NOTE(review): `shape` / `stride_*` presumably still need to point
/// to at least `rank` readable host elements — confirm against the
/// launcher source.
pub fn baracuda_kernels_binary_add_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `add`, f64 dtype, strided / broadcast path.
///
/// `numel` is the output element count, `rank` the number of valid
/// axes, and `shape` / `stride_*` are host-side arrays of at least
/// `rank` entries (stride 0 marks a broadcast axis). `workspace` /
/// `workspace_bytes` are unused (pass null and 0). Returns a status
/// code per the crate-level docs (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_add_f32_strided_run`.
pub fn baracuda_kernels_binary_add_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_add_f64_strided`.
///
/// Takes the same shape / stride / operand arguments as
/// `baracuda_kernels_binary_add_f64_strided_run` (no workspace or
/// stream) and returns a status code per the crate-level docs
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
/// NOTE(review): `shape` / `stride_*` presumably still need to point
/// to at least `rank` readable host elements — confirm against the
/// launcher source.
pub fn baracuda_kernels_binary_add_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `sub`, f32 dtype, strided / broadcast path.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_add_f32_strided_run`.
pub fn baracuda_kernels_binary_sub_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_sub_f32_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_sub_f32_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_sub_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `sub`, f16 dtype, strided / broadcast path.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_add_f32_strided_run`.
pub fn baracuda_kernels_binary_sub_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_sub_f16_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_sub_f16_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_sub_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `sub`, bf16 dtype, strided / broadcast path.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_add_f32_strided_run`.
pub fn baracuda_kernels_binary_sub_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_sub_bf16_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_sub_bf16_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_sub_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `sub`, f64 dtype, strided / broadcast path.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_add_f32_strided_run`.
pub fn baracuda_kernels_binary_sub_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_sub_f64_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_sub_f64_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_sub_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `mul`, f32 dtype, strided / broadcast path.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_add_f32_strided_run`.
pub fn baracuda_kernels_binary_mul_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_mul_f32_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_mul_f32_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_mul_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `mul`, f16 dtype, strided / broadcast path.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_add_f32_strided_run`.
pub fn baracuda_kernels_binary_mul_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_mul_f16_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_mul_f16_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_mul_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `mul`, bf16 dtype, strided / broadcast path.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_add_f32_strided_run`.
pub fn baracuda_kernels_binary_mul_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_mul_bf16_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_mul_bf16_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_mul_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `mul`, f64 dtype, strided / broadcast path.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_add_f32_strided_run`.
pub fn baracuda_kernels_binary_mul_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_mul_f64_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_mul_f64_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_mul_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `div`, f32 dtype, strided / broadcast path.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_add_f32_strided_run`.
pub fn baracuda_kernels_binary_div_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_div_f32_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_div_f32_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_div_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `div`, f16 dtype, strided / broadcast path.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_add_f32_strided_run`.
pub fn baracuda_kernels_binary_div_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_div_f16_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_div_f16_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_div_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `div`, bf16 dtype, strided / broadcast path.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_add_f32_strided_run`.
pub fn baracuda_kernels_binary_div_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_div_bf16_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_div_bf16_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_div_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `div`, f64 dtype, strided / broadcast path.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Same contract as `baracuda_kernels_binary_add_f32_strided_run`.
pub fn baracuda_kernels_binary_div_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_div_f64_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_div_f64_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_div_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `pow`, f32, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably `a` is the base and `b` the exponent — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_pow_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_pow_f32_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_pow_f32_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_pow_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `pow`, f16, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably `a` is the base and `b` the exponent — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_pow_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_pow_f16_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_pow_f16_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_pow_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `pow`, bf16, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably `a` is the base and `b` the exponent — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_pow_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_pow_bf16_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_pow_bf16_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_pow_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `pow`, f64, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably `a` is the base and `b` the exponent — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_pow_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_pow_f64_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_pow_f64_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_pow_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `atan2`, f32, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably computes `atan2(a, b)` elementwise — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_atan2_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_atan2_f32_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_atan2_f32_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_atan2_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `atan2`, f16, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably computes `atan2(a, b)` elementwise — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_atan2_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_atan2_f16_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_atan2_f16_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_atan2_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `atan2`, bf16, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably computes `atan2(a, b)` elementwise — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_atan2_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_atan2_bf16_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_atan2_bf16_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_atan2_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `atan2`, f64, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably computes `atan2(a, b)` elementwise — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_atan2_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_atan2_f64_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_atan2_f64_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_atan2_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `hypot`, f32, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_hypot_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_hypot_f32_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_hypot_f32_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_hypot_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `hypot`, f16, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_hypot_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_hypot_f16_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_hypot_f16_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_hypot_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `hypot`, bf16, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_hypot_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_hypot_bf16_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_hypot_bf16_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_hypot_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `hypot`, f64, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_hypot_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_hypot_f64_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_hypot_f64_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_hypot_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `copysign`, f32, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably magnitude from `a` and sign from `b` — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_copysign_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_copysign_f32_strided`.
///
/// Mirrors the arguments of
/// `baracuda_kernels_binary_copysign_f32_strided_run` minus `workspace`,
/// `workspace_bytes`, and `stream`; `y` is `*const` here. Returns an `i32`
/// status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_copysign_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `copysign`, f16, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably magnitude from `a` and sign from `b` — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_copysign_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_copysign_f16_strided`.
///
/// Mirrors the arguments of
/// `baracuda_kernels_binary_copysign_f16_strided_run` minus `workspace`,
/// `workspace_bytes`, and `stream`; `y` is `*const` here. Returns an `i32`
/// status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_copysign_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `copysign`, bf16, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably magnitude from `a` and sign from `b` — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_copysign_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_copysign_bf16_strided`.
///
/// Mirrors the arguments of
/// `baracuda_kernels_binary_copysign_bf16_strided_run` minus `workspace`,
/// `workspace_bytes`, and `stream`; `y` is `*const` here. Returns an `i32`
/// status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_copysign_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `copysign`, f64, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably magnitude from `a` and sign from `b` — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_copysign_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_copysign_f64_strided`.
///
/// Mirrors the arguments of
/// `baracuda_kernels_binary_copysign_f64_strided_run` minus `workspace`,
/// `workspace_bytes`, and `stream`; `y` is `*const` here. Returns an `i32`
/// status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_copysign_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `nextafter`, f32, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably steps from `a` toward `b` — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_nextafter_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_nextafter_f32_strided`.
///
/// Mirrors the arguments of
/// `baracuda_kernels_binary_nextafter_f32_strided_run` minus `workspace`,
/// `workspace_bytes`, and `stream`; `y` is `*const` here. Returns an `i32`
/// status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_nextafter_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `nextafter`, f16, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably steps from `a` toward `b` — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_nextafter_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_nextafter_f16_strided`.
///
/// Mirrors the arguments of
/// `baracuda_kernels_binary_nextafter_f16_strided_run` minus `workspace`,
/// `workspace_bytes`, and `stream`; `y` is `*const` here. Returns an `i32`
/// status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_nextafter_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `nextafter`, bf16, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably steps from `a` toward `b` — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_nextafter_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_nextafter_bf16_strided`.
///
/// Mirrors the arguments of
/// `baracuda_kernels_binary_nextafter_bf16_strided_run` minus `workspace`,
/// `workspace_bytes`, and `stream`; `y` is `*const` here. Returns an `i32`
/// status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_nextafter_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `nextafter`, f64, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably steps from `a` toward `b` — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_nextafter_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_nextafter_f64_strided`.
///
/// Mirrors the arguments of
/// `baracuda_kernels_binary_nextafter_f64_strided_run` minus `workspace`,
/// `workspace_bytes`, and `stream`; `y` is `*const` here. Returns an `i32`
/// status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_nextafter_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `fmin`, f32, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably NaN-suppressing like C `fmin`, in contrast to the
/// NaN-propagating `minimum` entry points — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_fmin_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_fmin_f32_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_fmin_f32_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_fmin_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `fmin`, f16, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably NaN-suppressing like C `fmin`, in contrast to the
/// NaN-propagating `minimum` entry points — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_fmin_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_fmin_f16_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_fmin_f16_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_fmin_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `fmin`, bf16, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably NaN-suppressing like C `fmin`, in contrast to the
/// NaN-propagating `minimum` entry points — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_fmin_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_fmin_bf16_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_fmin_bf16_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_fmin_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `fmin`, f64, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably NaN-suppressing like C `fmin`, in contrast to the
/// NaN-propagating `minimum` entry points — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_fmin_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_fmin_f64_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_fmin_f64_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_fmin_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `fmax`, f32, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably NaN-suppressing like C `fmax`, in contrast to the
/// NaN-propagating `maximum` entry points — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_fmax_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_fmax_f32_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_fmax_f32_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_fmax_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `fmax`, f16, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably NaN-suppressing like C `fmax`, in contrast to the
/// NaN-propagating `maximum` entry points — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_fmax_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_fmax_f16_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_fmax_f16_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_fmax_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `fmax`, bf16, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably NaN-suppressing like C `fmax`, in contrast to the
/// NaN-propagating `maximum` entry points — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_fmax_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_fmax_bf16_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_fmax_bf16_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_fmax_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `fmax`, f64, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NOTE(review): presumably NaN-suppressing like C `fmax`, in contrast to the
/// NaN-propagating `maximum` entry points — confirm.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_fmax_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_fmax_f64_strided`.
///
/// Mirrors the arguments of `baracuda_kernels_binary_fmax_f64_strided_run`
/// minus `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_fmax_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ----- Binary maximum (NaN-PROPAGATING), strided -------------------
/// Binary `maximum`, f32, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NaN-propagating, per this section's header comment.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_maximum_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_maximum_f32_strided`.
///
/// Mirrors the arguments of
/// `baracuda_kernels_binary_maximum_f32_strided_run` minus `workspace`,
/// `workspace_bytes`, and `stream`; `y` is `*const` here. Returns an `i32`
/// status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_maximum_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `maximum`, f16, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NaN-propagating, per this section's header comment.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_maximum_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_maximum_f16_strided`.
///
/// Mirrors the arguments of
/// `baracuda_kernels_binary_maximum_f16_strided_run` minus `workspace`,
/// `workspace_bytes`, and `stream`; `y` is `*const` here. Returns an `i32`
/// status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_maximum_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `maximum`, bf16, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NaN-propagating, per this section's header comment.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_maximum_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_maximum_bf16_strided`.
///
/// Mirrors the arguments of
/// `baracuda_kernels_binary_maximum_bf16_strided_run` minus `workspace`,
/// `workspace_bytes`, and `stream`; `y` is `*const` here. Returns an `i32`
/// status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_maximum_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `maximum`, f64, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NaN-propagating, per this section's header comment.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_maximum_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_maximum_f64_strided`.
///
/// Mirrors the arguments of
/// `baracuda_kernels_binary_maximum_f64_strided_run` minus `workspace`,
/// `workspace_bytes`, and `stream`; `y` is `*const` here. Returns an `i32`
/// status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_maximum_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ----- Binary minimum (NaN-PROPAGATING), strided -------------------
/// Binary `minimum`, f32, strided.
///
/// `a` and `b` are the read-only (`*const`) operands and `y` is the writable
/// (`*mut`) output; each has its own `*const i64` stride pointer. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes" list).
///
/// NaN-propagating, per this section's header comment.
///
/// # Safety
/// Per the crate-level contract, pointer arguments are dereferenced without
/// bounds checks, a non-null `workspace` must reference at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream. NOTE(review): presumably the same contract as
/// `baracuda_kernels_binary_add_f32_strided_run` — confirm.
pub fn baracuda_kernels_binary_minimum_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_minimum_f32_strided`.
///
/// Mirrors the arguments of
/// `baracuda_kernels_binary_minimum_f32_strided_run` minus `workspace`,
/// `workspace_bytes`, and `stream`; `y` is `*const` here. Returns an `i32`
/// status code (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_minimum_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `minimum`, f16, strided.
///
/// The `minimum` family is NaN-propagating.
/// Takes the shared `shape` plus one stride array per operand
/// (`stride_a`, `stride_b`, `stride_y`). Returns an `i32` status code
/// as listed in the crate-level "Status codes" section (`0` is success).
///
/// # Safety
/// Crate-level contract: pointers are dereferenced unchecked,
/// `workspace` (when non-null) must cover `workspace_bytes` of writable
/// device memory, and `stream` must be a valid CUDA stream.
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries
/// each — confirm.
pub fn baracuda_kernels_binary_minimum_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_minimum_f16_strided`.
///
/// Takes the same arguments as the matching `_strided_run` launcher
/// minus `workspace`, `workspace_bytes`, and `stream` (and `y` as
/// `*const`); returns an `i32` status code per the crate-level
/// "Status codes" section.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_minimum_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `minimum`, bf16, strided.
///
/// The `minimum` family is NaN-propagating.
/// Takes the shared `shape` plus one stride array per operand
/// (`stride_a`, `stride_b`, `stride_y`). Returns an `i32` status code
/// as listed in the crate-level "Status codes" section (`0` is success).
///
/// # Safety
/// Crate-level contract: pointers are dereferenced unchecked,
/// `workspace` (when non-null) must cover `workspace_bytes` of writable
/// device memory, and `stream` must be a valid CUDA stream.
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries
/// each — confirm.
pub fn baracuda_kernels_binary_minimum_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_minimum_bf16_strided`.
///
/// Takes the same arguments as the matching `_strided_run` launcher
/// minus `workspace`, `workspace_bytes`, and `stream` (and `y` as
/// `*const`); returns an `i32` status code per the crate-level
/// "Status codes" section.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_minimum_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `minimum`, f64, strided.
///
/// The `minimum` family is NaN-propagating.
/// Takes the shared `shape` plus one stride array per operand
/// (`stride_a`, `stride_b`, `stride_y`). Returns an `i32` status code
/// as listed in the crate-level "Status codes" section (`0` is success).
///
/// # Safety
/// Crate-level contract: pointers are dereferenced unchecked,
/// `workspace` (when non-null) must cover `workspace_bytes` of writable
/// device memory, and `stream` must be a valid CUDA stream.
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries
/// each — confirm.
pub fn baracuda_kernels_binary_minimum_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_minimum_f64_strided`.
///
/// Takes the same arguments as the matching `_strided_run` launcher
/// minus `workspace`, `workspace_bytes`, and `stream` (and `y` as
/// `*const`); returns an `i32` status code per the crate-level
/// "Status codes" section.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_minimum_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ----- Binary floor_divide, strided --------------------------------
/// Binary `floor_divide`, f32, strided.
///
/// Takes the shared `shape` plus one stride array per operand
/// (`stride_a`, `stride_b`, `stride_y`). Returns an `i32` status code
/// as listed in the crate-level "Status codes" section (`0` is success).
///
/// # Safety
/// Crate-level contract: pointers are dereferenced unchecked,
/// `workspace` (when non-null) must cover `workspace_bytes` of writable
/// device memory, and `stream` must be a valid CUDA stream.
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries
/// each — confirm.
pub fn baracuda_kernels_binary_floor_divide_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_floor_divide_f32_strided`.
///
/// Takes the same arguments as the matching `_strided_run` launcher
/// minus `workspace`, `workspace_bytes`, and `stream` (and `y` as
/// `*const`); returns an `i32` status code per the crate-level
/// "Status codes" section.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_floor_divide_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `floor_divide`, f16, strided.
///
/// Takes the shared `shape` plus one stride array per operand
/// (`stride_a`, `stride_b`, `stride_y`). Returns an `i32` status code
/// as listed in the crate-level "Status codes" section (`0` is success).
///
/// # Safety
/// Crate-level contract: pointers are dereferenced unchecked,
/// `workspace` (when non-null) must cover `workspace_bytes` of writable
/// device memory, and `stream` must be a valid CUDA stream.
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries
/// each — confirm.
pub fn baracuda_kernels_binary_floor_divide_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_floor_divide_f16_strided`.
///
/// Takes the same arguments as the matching `_strided_run` launcher
/// minus `workspace`, `workspace_bytes`, and `stream` (and `y` as
/// `*const`); returns an `i32` status code per the crate-level
/// "Status codes" section.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_floor_divide_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `floor_divide`, bf16, strided.
///
/// Takes the shared `shape` plus one stride array per operand
/// (`stride_a`, `stride_b`, `stride_y`). Returns an `i32` status code
/// as listed in the crate-level "Status codes" section (`0` is success).
///
/// # Safety
/// Crate-level contract: pointers are dereferenced unchecked,
/// `workspace` (when non-null) must cover `workspace_bytes` of writable
/// device memory, and `stream` must be a valid CUDA stream.
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries
/// each — confirm.
pub fn baracuda_kernels_binary_floor_divide_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_floor_divide_bf16_strided`.
///
/// Takes the same arguments as the matching `_strided_run` launcher
/// minus `workspace`, `workspace_bytes`, and `stream` (and `y` as
/// `*const`); returns an `i32` status code per the crate-level
/// "Status codes" section.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_floor_divide_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `floor_divide`, f64, strided.
///
/// Takes the shared `shape` plus one stride array per operand
/// (`stride_a`, `stride_b`, `stride_y`). Returns an `i32` status code
/// as listed in the crate-level "Status codes" section (`0` is success).
///
/// # Safety
/// Crate-level contract: pointers are dereferenced unchecked,
/// `workspace` (when non-null) must cover `workspace_bytes` of writable
/// device memory, and `stream` must be a valid CUDA stream.
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries
/// each — confirm.
pub fn baracuda_kernels_binary_floor_divide_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_floor_divide_f64_strided`.
///
/// Takes the same arguments as the matching `_strided_run` launcher
/// minus `workspace`, `workspace_bytes`, and `stream` (and `y` as
/// `*const`); returns an `i32` status code per the crate-level
/// "Status codes" section.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_floor_divide_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ----- Binary mod (Python-style, sign of b), strided ---------------
/// Binary `mod`, f32, strided.
///
/// Python-style `mod`: the result takes the sign of `b`.
/// Takes the shared `shape` plus one stride array per operand
/// (`stride_a`, `stride_b`, `stride_y`). Returns an `i32` status code
/// as listed in the crate-level "Status codes" section (`0` is success).
///
/// # Safety
/// Crate-level contract: pointers are dereferenced unchecked,
/// `workspace` (when non-null) must cover `workspace_bytes` of writable
/// device memory, and `stream` must be a valid CUDA stream.
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries
/// each — confirm.
pub fn baracuda_kernels_binary_mod_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_mod_f32_strided`.
///
/// Takes the same arguments as the matching `_strided_run` launcher
/// minus `workspace`, `workspace_bytes`, and `stream` (and `y` as
/// `*const`); returns an `i32` status code per the crate-level
/// "Status codes" section.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_mod_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `mod`, f16, strided.
///
/// Python-style `mod`: the result takes the sign of `b`.
/// Takes the shared `shape` plus one stride array per operand
/// (`stride_a`, `stride_b`, `stride_y`). Returns an `i32` status code
/// as listed in the crate-level "Status codes" section (`0` is success).
///
/// # Safety
/// Crate-level contract: pointers are dereferenced unchecked,
/// `workspace` (when non-null) must cover `workspace_bytes` of writable
/// device memory, and `stream` must be a valid CUDA stream.
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries
/// each — confirm.
pub fn baracuda_kernels_binary_mod_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_mod_f16_strided`.
///
/// Takes the same arguments as the matching `_strided_run` launcher
/// minus `workspace`, `workspace_bytes`, and `stream` (and `y` as
/// `*const`); returns an `i32` status code per the crate-level
/// "Status codes" section.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_mod_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `mod`, bf16, strided.
///
/// Python-style `mod`: the result takes the sign of `b`.
/// Takes the shared `shape` plus one stride array per operand
/// (`stride_a`, `stride_b`, `stride_y`). Returns an `i32` status code
/// as listed in the crate-level "Status codes" section (`0` is success).
///
/// # Safety
/// Crate-level contract: pointers are dereferenced unchecked,
/// `workspace` (when non-null) must cover `workspace_bytes` of writable
/// device memory, and `stream` must be a valid CUDA stream.
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries
/// each — confirm.
pub fn baracuda_kernels_binary_mod_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_mod_bf16_strided`.
///
/// Takes the same arguments as the matching `_strided_run` launcher
/// minus `workspace`, `workspace_bytes`, and `stream` (and `y` as
/// `*const`); returns an `i32` status code per the crate-level
/// "Status codes" section.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_mod_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `mod`, f64, strided.
///
/// Python-style `mod`: the result takes the sign of `b`.
/// Takes the shared `shape` plus one stride array per operand
/// (`stride_a`, `stride_b`, `stride_y`). Returns an `i32` status code
/// as listed in the crate-level "Status codes" section (`0` is success).
///
/// # Safety
/// Crate-level contract: pointers are dereferenced unchecked,
/// `workspace` (when non-null) must cover `workspace_bytes` of writable
/// device memory, and `stream` must be a valid CUDA stream.
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries
/// each — confirm.
pub fn baracuda_kernels_binary_mod_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_mod_f64_strided`.
///
/// Takes the same arguments as the matching `_strided_run` launcher
/// minus `workspace`, `workspace_bytes`, and `stream` (and `y` as
/// `*const`); returns an `i32` status code per the crate-level
/// "Status codes" section.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_mod_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
// ----- Binary remainder (C-style, sign of a), strided --------------
/// Binary `remainder`, f32, strided.
///
/// C-style `remainder`: the result takes the sign of `a`.
/// Takes the shared `shape` plus one stride array per operand
/// (`stride_a`, `stride_b`, `stride_y`). Returns an `i32` status code
/// as listed in the crate-level "Status codes" section (`0` is success).
///
/// # Safety
/// Crate-level contract: pointers are dereferenced unchecked,
/// `workspace` (when non-null) must cover `workspace_bytes` of writable
/// device memory, and `stream` must be a valid CUDA stream.
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries
/// each — confirm.
pub fn baracuda_kernels_binary_remainder_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_remainder_f32_strided`.
///
/// Takes the same arguments as the matching `_strided_run` launcher
/// minus `workspace`, `workspace_bytes`, and `stream` (and `y` as
/// `*const`); returns an `i32` status code per the crate-level
/// "Status codes" section.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_remainder_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `remainder`, f16, strided.
///
/// C-style `remainder`: the result takes the sign of `a`.
/// Takes the shared `shape` plus one stride array per operand
/// (`stride_a`, `stride_b`, `stride_y`). Returns an `i32` status code
/// as listed in the crate-level "Status codes" section (`0` is success).
///
/// # Safety
/// Crate-level contract: pointers are dereferenced unchecked,
/// `workspace` (when non-null) must cover `workspace_bytes` of writable
/// device memory, and `stream` must be a valid CUDA stream.
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries
/// each — confirm.
pub fn baracuda_kernels_binary_remainder_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_remainder_f16_strided`.
///
/// Takes the same arguments as the matching `_strided_run` launcher
/// minus `workspace`, `workspace_bytes`, and `stream` (and `y` as
/// `*const`); returns an `i32` status code per the crate-level
/// "Status codes" section.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_remainder_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `remainder`, bf16, strided.
///
/// C-style `remainder`: the result takes the sign of `a`.
/// Takes the shared `shape` plus one stride array per operand
/// (`stride_a`, `stride_b`, `stride_y`). Returns an `i32` status code
/// as listed in the crate-level "Status codes" section (`0` is success).
///
/// # Safety
/// Crate-level contract: pointers are dereferenced unchecked,
/// `workspace` (when non-null) must cover `workspace_bytes` of writable
/// device memory, and `stream` must be a valid CUDA stream.
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries
/// each — confirm.
pub fn baracuda_kernels_binary_remainder_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_remainder_bf16_strided`.
///
/// Takes the same arguments as the matching `_strided_run` launcher
/// minus `workspace`, `workspace_bytes`, and `stream` (and `y` as
/// `*const`); returns an `i32` status code per the crate-level
/// "Status codes" section.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_remainder_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary `remainder`, f64, strided.
///
/// C-style `remainder`: the result takes the sign of `a`.
/// Takes the shared `shape` plus one stride array per operand
/// (`stride_a`, `stride_b`, `stride_y`). Returns an `i32` status code
/// as listed in the crate-level "Status codes" section (`0` is success).
///
/// # Safety
/// Crate-level contract: pointers are dereferenced unchecked,
/// `workspace` (when non-null) must cover `workspace_bytes` of writable
/// device memory, and `stream` must be a valid CUDA stream.
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries
/// each — confirm.
pub fn baracuda_kernels_binary_remainder_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_remainder_f64_strided`.
///
/// Takes the same arguments as the matching `_strided_run` launcher
/// minus `workspace`, `workspace_bytes`, and `stream` (and `y` as
/// `*const`); returns an `i32` status code per the crate-level
/// "Status codes" section.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_binary_remainder_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
}
// ============================================================================
// Elementwise — ternary (3→1) ops
// ============================================================================
//
// 3-input, 1-output pointwise ops with same-dtype operands. Same
// INSTANTIATE-driven kernel family as binary, with one extra input
// (`c`) and one extra stride array (`stride_c`) for the strided path.
//
// Wired matrix: {Clamp, Fma} × {f32, f16, bf16, f64} = 8 cells, each
// with contig + strided launchers (4 symbols per cell: `_run`,
// `_can_implement`, `_strided_run`, `_strided_can_implement`).
// {Addcmul, Addcdiv} are reserved-but-deferred — they take a scalar
// runtime parameter not yet representable in the ternary plan shape.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Ternary elementwise `clamp`, f32, contig fast path.
///
/// `y = min(max(a, b), c)` where `a` is the input, `b` is the lower
/// bound, `c` is the upper bound — matches PyTorch's
/// `torch.clamp(x, min=lo, max=hi)` semantics with `a = x`, `b = lo`,
/// `c = hi`. The caller is responsible for `lo <= hi`; if not, the
/// output is `hi` (PyTorch's convention: max wins).
///
/// Returns an `i32` status code; see the crate-level "Status codes"
/// section (`0` is success).
///
/// # Safety
/// All device pointers must remain valid for the duration of the
/// launch. `a`, `b`, `c`, `y` must each point to at least `numel`
/// `float`s. Aliasing `y` with any input is safe — the kernel reads
/// each input cell before writing each output cell per thread.
/// Per the crate-level contract, `workspace` (when non-null) must
/// point to at least `workspace_bytes` of writable device memory and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_ternary_clamp_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `ternary_clamp_f32`.
///
/// Takes the contig launcher's `numel` and operand pointers (no
/// workspace or stream) and returns an `i32` status code per the
/// crate-level "Status codes" section.
/// NOTE(review): presumably host-side checks only, as documented for
/// the strided `_can_implement` variants — confirm.
pub fn baracuda_kernels_ternary_clamp_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *const c_void,
) -> i32;
/// Ternary elementwise `clamp`, f32, strided / broadcast path.
/// This is the ternary-strided trailblazer — its safety contract
/// (including aliasing) carries over to every ternary strided
/// launcher across all dtypes.
///
/// Handles non-contig views and broadcast — each input's
/// per-axis stride may be 0 (broadcast along that axis) or any
/// integer (transposed / sliced view). The PyTorch convention
/// `clamp(x, min=lo, max=hi)` typically has `lo` / `hi` as scalars
/// — represent them as rank-N tensors with `shape[d] = 1` and
/// `stride[d] = 0` on every axis.
///
/// **Aliasing (Phase 62)**: aliasing `y` with any input (`a == y`,
/// `b == y`, or `c == y`) is safe IF AND ONLY IF the aliased
/// input's stride array equals `stride_y` element-for-element
/// (use [`baracuda_kernels_types::strides_equal`] to check). With
/// equal strides, each thread reads its own `off_y` cell from the
/// aliased input then writes the same cell. With unequal strides
/// (including the common broadcast-of-scalar case where the input
/// stride has zeros), different threads can read cells that other
/// threads have already overwritten — silent data corruption. The
/// kernel does no validation; this is the caller's contract. Note
/// that broadcasting the lo/hi bounds via zero strides is itself
/// fine — what's NOT fine is aliasing a broadcast input pointer
/// with `y`.
///
/// Additional preconditions (apply with or without aliasing): no
/// zero strides on `y`, and `(shape, stride_y)` must specify a
/// valid permutation (no two linear `i` values mapping to the
/// same `off_y` cell).
///
/// This contract is stable across baracuda versions.
///
/// Relative to the binary strided launchers this takes one extra
/// input (`c`) and one extra stride array (`stride_c`). Returns an
/// `i32` status code per the crate-level "Status codes" section.
pub fn baracuda_kernels_ternary_clamp_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_c: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `ternary_clamp_f32_strided`.
///
/// Takes the same arguments as the strided `_run` launcher minus
/// `workspace`, `workspace_bytes`, and `stream` (and `y` as `*const`);
/// returns an `i32` status code.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_ternary_clamp_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_c: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *const c_void,
) -> i32;
/// Ternary elementwise `clamp`, f16, contig fast path.
///
/// See `baracuda_kernels_ternary_clamp_f32_run`. Inputs and output
/// are `__half`; the kernel applies a single f32-detour
/// `fminf(fmaxf(...))` per cell — matches host
/// `half::f16` round-to-f32-min/max-round-back-to-f16 semantics.
///
/// Returns an `i32` status code per the crate-level "Status codes"
/// section.
pub fn baracuda_kernels_ternary_clamp_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `ternary_clamp_f16`.
///
/// Takes the contig launcher's `numel` and operand pointers (no
/// workspace or stream) and returns an `i32` status code.
/// NOTE(review): presumably host-side checks only — confirm.
pub fn baracuda_kernels_ternary_clamp_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *const c_void,
) -> i32;
/// Ternary elementwise `clamp`, f16, strided / broadcast path.
///
/// The safety and aliasing contract is the one documented on
/// `baracuda_kernels_ternary_clamp_f32_strided_run`, which carries
/// over to every ternary strided launcher. Returns an `i32` status
/// code per the crate-level "Status codes" section.
pub fn baracuda_kernels_ternary_clamp_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_c: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `ternary_clamp_f16_strided`.
///
/// Takes the same arguments as the strided `_run` launcher minus
/// `workspace`, `workspace_bytes`, and `stream` (and `y` as `*const`);
/// returns an `i32` status code.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_ternary_clamp_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_c: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *const c_void,
) -> i32;
/// Ternary elementwise `clamp`, bf16, contig fast path.
///
/// See `baracuda_kernels_ternary_clamp_f32_run`. Same f32-detour
/// pipeline as the f16 variant but with `__nv_bfloat16` storage.
///
/// Returns an `i32` status code per the crate-level "Status codes"
/// section.
pub fn baracuda_kernels_ternary_clamp_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `ternary_clamp_bf16`.
///
/// Takes the contig launcher's `numel` and operand pointers (no
/// workspace or stream) and returns an `i32` status code.
/// NOTE(review): presumably host-side checks only — confirm.
pub fn baracuda_kernels_ternary_clamp_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *const c_void,
) -> i32;
/// Ternary elementwise `clamp`, bf16, strided / broadcast path.
///
/// The safety and aliasing contract is the one documented on
/// `baracuda_kernels_ternary_clamp_f32_strided_run`, which carries
/// over to every ternary strided launcher. Returns an `i32` status
/// code per the crate-level "Status codes" section.
pub fn baracuda_kernels_ternary_clamp_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_c: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `ternary_clamp_bf16_strided`.
///
/// Takes the same arguments as the strided `_run` launcher minus
/// `workspace`, `workspace_bytes`, and `stream` (and `y` as `*const`);
/// returns an `i32` status code.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_ternary_clamp_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_c: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *const c_void,
) -> i32;
/// Ternary elementwise `clamp`, f64, contig fast path.
///
/// See `baracuda_kernels_ternary_clamp_f32_run`. Inputs and output
/// are `double`; uses `fmin(fmax(...))` directly (no detour).
///
/// Returns an `i32` status code per the crate-level "Status codes"
/// section.
pub fn baracuda_kernels_ternary_clamp_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `ternary_clamp_f64`.
///
/// Takes the contig launcher's `numel` and operand pointers (no
/// workspace or stream) and returns an `i32` status code.
/// NOTE(review): presumably host-side checks only — confirm.
pub fn baracuda_kernels_ternary_clamp_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *const c_void,
) -> i32;
/// Ternary elementwise `clamp`, f64, strided / broadcast path.
///
/// The safety and aliasing contract is the one documented on
/// `baracuda_kernels_ternary_clamp_f32_strided_run`, which carries
/// over to every ternary strided launcher. Returns an `i32` status
/// code per the crate-level "Status codes" section.
pub fn baracuda_kernels_ternary_clamp_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_c: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `ternary_clamp_f64_strided`.
///
/// Takes the same arguments as the strided `_run` launcher minus
/// `workspace`, `workspace_bytes`, and `stream` (and `y` as `*const`);
/// returns an `i32` status code.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_ternary_clamp_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_c: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *const c_void,
) -> i32;
}
// --- Fma --------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Ternary elementwise `fma`, f32, contig fast path.
///
/// `y = a * b + c` — computed as two separate rounding steps
/// (multiply then add), NOT the IEEE single-rounding fma. This
/// matches PyTorch's `torch.addcmul(c, a, b)` with implicit
/// `value=1` and gives bit-exact compare with the host reference's
/// `a * b + c` on f32 / f64. The f16 / bf16 variants follow the
/// usual f32-detour pattern (each scalar op promotes to f32, runs
/// once, rounds back).
///
/// Returns an `i32` status code; see the crate-level "Status codes"
/// section (`0` is success).
///
/// # Safety
/// All device pointers must remain valid for the duration of the
/// launch. `a`, `b`, `c`, `y` must each point to at least `numel`
/// `float`s. Aliasing `y` with any input is safe — each thread
/// reads each input cell before writing the output cell.
/// Per the crate-level contract, `workspace` (when non-null) must
/// point to at least `workspace_bytes` of writable device memory and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_ternary_fma_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `ternary_fma_f32`.
///
/// Takes the contig launcher's `numel` and operand pointers (no
/// workspace or stream) and returns an `i32` status code per the
/// crate-level "Status codes" section.
/// NOTE(review): presumably host-side checks only, as documented for
/// the strided `_can_implement` variants — confirm.
pub fn baracuda_kernels_ternary_fma_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *const c_void,
) -> i32;
/// Ternary elementwise `fma`, f32, strided / broadcast path.
///
/// Same `y = a * b + c` as `baracuda_kernels_ternary_fma_f32_run`,
/// with one stride array per operand. The safety and aliasing
/// contract is the one documented on
/// `baracuda_kernels_ternary_clamp_f32_strided_run`, which states
/// that it carries over to every ternary strided launcher.
/// Returns an `i32` status code per the crate-level "Status codes"
/// section.
pub fn baracuda_kernels_ternary_fma_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_c: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `ternary_fma_f32_strided`.
///
/// Takes the same arguments as the strided `_run` launcher minus
/// `workspace`, `workspace_bytes`, and `stream` (and `y` as `*const`);
/// returns an `i32` status code.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_ternary_fma_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_c: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *const c_void,
) -> i32;
/// Ternary elementwise `fma`, f16, contig fast path.
///
/// See `baracuda_kernels_ternary_fma_f32_run`; the f16 variant
/// follows the f32-detour pattern described there. Returns an `i32`
/// status code per the crate-level "Status codes" section.
/// NOTE(review): operands presumably hold at least `numel` `__half`
/// elements, mirroring the clamp f16 launcher — confirm.
pub fn baracuda_kernels_ternary_fma_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `ternary_fma_f16`.
///
/// Takes the contig launcher's `numel` and operand pointers (no
/// workspace or stream) and returns an `i32` status code.
/// NOTE(review): presumably host-side checks only — confirm.
pub fn baracuda_kernels_ternary_fma_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *const c_void,
) -> i32;
/// Ternary elementwise `fma`, f16, strided / broadcast path.
///
/// The safety and aliasing contract is the one documented on
/// `baracuda_kernels_ternary_clamp_f32_strided_run`, which carries
/// over to every ternary strided launcher. Returns an `i32` status
/// code per the crate-level "Status codes" section.
pub fn baracuda_kernels_ternary_fma_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_c: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `ternary_fma_f16_strided`.
///
/// Takes the same arguments as the strided `_run` launcher minus
/// `workspace`, `workspace_bytes`, and `stream` (and `y` as `*const`);
/// returns an `i32` status code.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_ternary_fma_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_c: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *const c_void,
) -> i32;
/// Ternary elementwise `fma`, bf16, contig fast path.
///
/// See `baracuda_kernels_ternary_fma_f32_run`; the bf16 variant
/// follows the f32-detour pattern described there. Returns an `i32`
/// status code per the crate-level "Status codes" section.
/// NOTE(review): operands presumably hold at least `numel`
/// `__nv_bfloat16` elements, mirroring the clamp bf16 launcher —
/// confirm.
pub fn baracuda_kernels_ternary_fma_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `ternary_fma_bf16`.
///
/// Takes the contig launcher's `numel` and operand pointers (no
/// workspace or stream) and returns an `i32` status code.
/// NOTE(review): presumably host-side checks only — confirm.
pub fn baracuda_kernels_ternary_fma_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *const c_void,
) -> i32;
/// Ternary elementwise `fma`, bf16, strided / broadcast path.
///
/// The safety and aliasing contract is the one documented on
/// `baracuda_kernels_ternary_clamp_f32_strided_run`, which carries
/// over to every ternary strided launcher. Returns an `i32` status
/// code per the crate-level "Status codes" section.
pub fn baracuda_kernels_ternary_fma_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_c: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `ternary_fma_bf16_strided`.
///
/// Takes the same arguments as the strided `_run` launcher minus
/// `workspace`, `workspace_bytes`, and `stream` (and `y` as `*const`);
/// returns an `i32` status code.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_ternary_fma_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_c: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *const c_void,
) -> i32;
/// Ternary elementwise `fma`, f64, contig fast path.
///
/// See `baracuda_kernels_ternary_fma_f32_run`, which documents the
/// two-step `a * b + c` rounding for both f32 and f64. Returns an
/// `i32` status code per the crate-level "Status codes" section.
/// NOTE(review): operands presumably hold at least `numel` `double`
/// elements, mirroring the clamp f64 launcher — confirm.
pub fn baracuda_kernels_ternary_fma_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `ternary_fma_f64`.
///
/// Takes the contig launcher's `numel` and operand pointers (no
/// workspace or stream) and returns an `i32` status code.
/// NOTE(review): presumably host-side checks only — confirm.
pub fn baracuda_kernels_ternary_fma_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *const c_void,
) -> i32;
/// Ternary elementwise `fma`, f64, strided / broadcast path.
///
/// The safety and aliasing contract is the one documented on
/// `baracuda_kernels_ternary_clamp_f32_strided_run`, which carries
/// over to every ternary strided launcher. Returns an `i32` status
/// code per the crate-level "Status codes" section.
pub fn baracuda_kernels_ternary_fma_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_c: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `ternary_fma_f64_strided`.
///
/// Takes the same arguments as the strided `_run` launcher minus
/// `workspace`, `workspace_bytes`, and `stream` (and `y` as `*const`);
/// returns an `i32` status code.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_ternary_fma_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_c: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
c: *const c_void,
y: *const c_void,
) -> i32;
}
// ============================================================================
// Reductions — Phase 4 trailblazer (axis reduction)
// ============================================================================
//
// `y = reduce(x, axis=k)` with keepdim=true (output shape == input
// shape but the reduced axis collapses to size 1). Single-axis only
// today — multi-axis / full-tensor reductions are fanout. Naive
// implementation: one thread per output cell.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Sum reduction along one axis, f32, naive thread-per-output-cell.
///
/// `output_shape` matches input shape with `[reduce_axis]` set to 1.
/// `reduce_extent` is the input's extent along the reduced axis.
/// `reduce_stride_x` is the input stride along the reduced axis
/// (in elements).
///
/// Returns an `i32` status code; see the crate-level "Status codes"
/// section (`0` is success).
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// Per the crate-level contract, `workspace` (when non-null) must
/// point to at least `workspace_bytes` of writable device memory and
/// `stream` must be a valid CUDA stream.
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably the host arrays referred to above, `rank` entries each
/// — confirm.
pub fn baracuda_kernels_reduce_sum_f32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_f32`.
///
/// Takes the same arguments as `baracuda_kernels_reduce_sum_f32_run`
/// minus `workspace`, `workspace_bytes`, and `stream` (with `y` as
/// `*const`), and returns an `i32` status code as listed in the
/// crate-level "Status codes" section.
pub fn baracuda_kernels_reduce_sum_f32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Sum reduction along one axis, f16.
///
/// Same parameter shape as the f32 variant; functor specializes the
/// accumulator op through the standard f32-detour pattern
/// (`__half2float` / `+` / `__float2half`).
///
/// Returns an `i32` status code; see the crate-level "Status codes"
/// section (`0` is success).
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// Per the crate-level contract, `workspace` (when non-null) must
/// point to at least `workspace_bytes` of writable device memory and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_sum_f16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_f16`.
///
/// Takes the same arguments as `baracuda_kernels_reduce_sum_f16_run`
/// minus `workspace`, `workspace_bytes`, and `stream` (with `y` as
/// `*const`), and returns an `i32` status code as listed in the
/// crate-level "Status codes" section.
pub fn baracuda_kernels_reduce_sum_f16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Sum reduction along one axis, bf16 (f32-detour functor).
///
/// Takes the same parameter list as
/// `baracuda_kernels_reduce_sum_f32_run`, which documents
/// `output_shape`, `reduce_extent`, and `reduce_stride_x`. Returns an
/// `i32` status code per the crate-level "Status codes" section.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// Per the crate-level contract, `workspace` (when non-null) must
/// point to at least `workspace_bytes` of writable device memory and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_sum_bf16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_bf16`.
///
/// Takes the same arguments as `baracuda_kernels_reduce_sum_bf16_run`
/// minus `workspace`, `workspace_bytes`, and `stream` (with `y` as
/// `*const`), and returns an `i32` status code as listed in the
/// crate-level "Status codes" section.
pub fn baracuda_kernels_reduce_sum_bf16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Sum reduction along one axis, f64.
///
/// Takes the same parameter list as
/// `baracuda_kernels_reduce_sum_f32_run`, which documents
/// `output_shape`, `reduce_extent`, and `reduce_stride_x`. Returns an
/// `i32` status code per the crate-level "Status codes" section.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// Per the crate-level contract, `workspace` (when non-null) must
/// point to at least `workspace_bytes` of writable device memory and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_sum_f64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_f64`.
///
/// Takes the same arguments as `baracuda_kernels_reduce_sum_f64_run`
/// minus `workspace`, `workspace_bytes`, and `stream` (with `y` as
/// `*const`), and returns an `i32` status code as listed in the
/// crate-level "Status codes" section.
pub fn baracuda_kernels_reduce_sum_f64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Mean reduction along one axis, f32. Sum then divide by extent.
///
/// Takes the same parameter list as
/// `baracuda_kernels_reduce_sum_f32_run`, which documents
/// `output_shape`, `reduce_extent`, and `reduce_stride_x`. Returns an
/// `i32` status code per the crate-level "Status codes" section.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// Per the crate-level contract, `workspace` (when non-null) must
/// point to at least `workspace_bytes` of writable device memory and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_mean_f32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_mean_f32`.
///
/// Takes the same arguments as `baracuda_kernels_reduce_mean_f32_run`
/// minus `workspace`, `workspace_bytes`, and `stream` (with `y` as
/// `*const`), and returns an `i32` status code as listed in the
/// crate-level "Status codes" section.
pub fn baracuda_kernels_reduce_mean_f32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Mean reduction along one axis, f16 (f32-detour for sum + divide).
///
/// Takes the same parameter list as
/// `baracuda_kernels_reduce_sum_f32_run`, which documents
/// `output_shape`, `reduce_extent`, and `reduce_stride_x`. Returns an
/// `i32` status code per the crate-level "Status codes" section.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// Per the crate-level contract, `workspace` (when non-null) must
/// point to at least `workspace_bytes` of writable device memory and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_mean_f16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_mean_f16`.
///
/// Takes the same arguments as `baracuda_kernels_reduce_mean_f16_run`
/// minus `workspace`, `workspace_bytes`, and `stream` (with `y` as
/// `*const`), and returns an `i32` status code as listed in the
/// crate-level "Status codes" section.
pub fn baracuda_kernels_reduce_mean_f16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Mean reduction along one axis, bf16 (f32-detour for sum + divide).
///
/// Takes the same parameter list as
/// `baracuda_kernels_reduce_sum_f32_run`, which documents
/// `output_shape`, `reduce_extent`, and `reduce_stride_x`. Returns an
/// `i32` status code per the crate-level "Status codes" section.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// Per the crate-level contract, `workspace` (when non-null) must
/// point to at least `workspace_bytes` of writable device memory and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_mean_bf16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_mean_bf16`.
///
/// Takes the same arguments as `baracuda_kernels_reduce_mean_bf16_run`
/// minus `workspace`, `workspace_bytes`, and `stream` (with `y` as
/// `*const`), and returns an `i32` status code as listed in the
/// crate-level "Status codes" section.
pub fn baracuda_kernels_reduce_mean_bf16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Mean reduction along one axis, f64.
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably the host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// `workspace`, when non-null, must cover `workspace_bytes` of
/// writable device memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_mean_f64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_mean_f64`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_mean_f64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Max reduction along one axis, f32. `init = -INFINITY`, `fmaxf`.
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably the host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// `workspace`, when non-null, must cover `workspace_bytes` of
/// writable device memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_max_f32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_f32`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_max_f32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Max reduction along one axis, f16 (f32-detour fmaxf).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably the host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// `workspace`, when non-null, must cover `workspace_bytes` of
/// writable device memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_max_f16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_f16`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_max_f16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Max reduction along one axis, bf16 (f32-detour fmaxf).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably the host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// `workspace`, when non-null, must cover `workspace_bytes` of
/// writable device memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_max_bf16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_bf16`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_max_bf16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Max reduction along one axis, f64.
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably the host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// `workspace`, when non-null, must cover `workspace_bytes` of
/// writable device memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_max_f64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_f64`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_max_f64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Min reduction along one axis, f32. `init = +INFINITY`, `fminf`.
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably the host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// `workspace`, when non-null, must cover `workspace_bytes` of
/// writable device memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_min_f32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_min_f32`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_min_f32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Min reduction along one axis, f16 (f32-detour fminf).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably the host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// `workspace`, when non-null, must cover `workspace_bytes` of
/// writable device memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_min_f16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_min_f16`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_min_f16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Min reduction along one axis, bf16 (f32-detour fminf).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably the host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// `workspace`, when non-null, must cover `workspace_bytes` of
/// writable device memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_min_bf16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_min_bf16`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_min_bf16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Min reduction along one axis, f64.
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably the host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// `workspace`, when non-null, must cover `workspace_bytes` of
/// writable device memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_min_f64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_min_f64`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_min_f64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Product reduction along one axis, f32. `init = 1`, op = `*`.
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably the host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// `workspace`, when non-null, must cover `workspace_bytes` of
/// writable device memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_prod_f32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_f32`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_prod_f32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Product reduction along one axis, f16 (f32-detour multiply).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably the host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// `workspace`, when non-null, must cover `workspace_bytes` of
/// writable device memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_prod_f16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_f16`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_prod_f16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Product reduction along one axis, bf16 (f32-detour multiply).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably the host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// `workspace`, when non-null, must cover `workspace_bytes` of
/// writable device memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_prod_bf16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_bf16`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_prod_bf16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Product reduction along one axis, f64.
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably the host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
/// `workspace`, when non-null, must cover `workspace_bytes` of
/// writable device memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_prod_f64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_f64`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_prod_f64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// ====================================================================
// Phase 37 Gap 1b — Integer-dtype single-axis Reduce family.
//
// Coverage: Sum / Min / Max / Prod × {u8, i8, u32, i16, i32, i64}.
//
// **Sum and Prod accumulator widening contract**: the internal
// accumulator is `i64` (signed dtypes) or `u64` (unsigned), and the
// result is **narrowed (wraps on overflow) back to the input
// dtype at the store site**. This matches Fuel's CPU reference,
// which performs the reduction in the same dtype as input/output
// and accepts the wrap. The widening only affects the bit-level
// result when the unwrapped infinite-precision answer happens to
// straddle multiple 2^N boundaries during accumulation — without
// widening, the partial-sum wrapping order would diverge from the
// CPU's left-to-right modulo-2^N accumulation. With widening, the
// GPU and CPU agree on the bits modulo 2^N.
//
// Min and Max use same-dtype throughout (no overflow concern).
/// `sum(x, axis=k)` with u8 input/output (wider u64 accumulator,
/// wrap-on-overflow narrow on store).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_sum_u8_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_u8`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_sum_u8_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `sum(x, axis=k)` with i8 input/output (wider i64 accumulator).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_sum_i8_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_i8`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_sum_i8_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `sum(x, axis=k)` with u32 input/output (wider u64 accumulator).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_sum_u32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_u32`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_sum_u32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `sum(x, axis=k)` with i16 input/output (wider i64 accumulator).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_sum_i16_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_i16`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_sum_i16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `sum(x, axis=k)` with i32 input/output (wider i64 accumulator).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_sum_i32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_i32`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_sum_i32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `sum(x, axis=k)` with i64 input/output. Accumulator and output
/// share dtype; modulo-2^64 wrap is the natural device behaviour.
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_sum_i64_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_i64`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_sum_i64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `min(x, axis=k)` with u8 input/output (same-dtype, init = `UINT8_MAX`).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_min_u8_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_min_u8`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_min_u8_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `min(x, axis=k)` with i8 input/output (init = `INT8_MAX`).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_min_i8_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_min_i8`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_min_i8_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `min(x, axis=k)` with u32 input/output (init = `UINT32_MAX`).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_min_u32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_min_u32`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_min_u32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `min(x, axis=k)` with i16 input/output (init = `INT16_MAX`).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_min_i16_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_min_i16`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_min_i16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `min(x, axis=k)` with i32 input/output (init = `INT32_MAX`).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_min_i32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_min_i32`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_min_i32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `min(x, axis=k)` with i64 input/output (init = `INT64_MAX`).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_min_i64_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_min_i64`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_min_i64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `max(x, axis=k)` with u8 input/output (init = `0`).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_max_u8_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_u8`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_max_u8_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `max(x, axis=k)` with i8 input/output (init = `INT8_MIN`).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_max_i8_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_i8`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_max_i8_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `max(x, axis=k)` with u32 input/output (init = `0`).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_max_u32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_u32`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_max_u32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `max(x, axis=k)` with i16 input/output (init = `INT16_MIN`).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_max_i16_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_i16`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_max_i16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `max(x, axis=k)` with i32 input/output (init = `INT32_MIN`).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_max_i32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_i32`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_max_i32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `max(x, axis=k)` with i64 input/output (init = `INT64_MIN`).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_max_i64_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_i64`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_max_i64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `prod(x, axis=k)` with u8 input/output (wider u64 accumulator,
/// wrap-on-overflow narrow on store).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_prod_u8_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_u8`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_prod_u8_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `prod(x, axis=k)` with i8 input/output (wider i64 accumulator).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_prod_i8_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_i8`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_prod_i8_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `prod(x, axis=k)` with u32 input/output (wider u64 accumulator).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_prod_u32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_u32`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_prod_u32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `prod(x, axis=k)` with i16 input/output (wider i64 accumulator).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_prod_i16_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_i16`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_prod_i16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `prod(x, axis=k)` with i32 input/output (wider i64 accumulator).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_prod_i32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_i32`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_prod_i32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `prod(x, axis=k)` with i64 input/output. Modulo-2^64 wrap.
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_prod_i64_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_i64`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_prod_i64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// Norm2 reduction along one axis, f32. `y = sqrt(sum(x*x))` —
/// shares the simple-reduce parameter shape.
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_norm2_f32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_norm2_f32`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_norm2_f32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Norm2 reduction along one axis, f16 (f32-detour functor + sqrt).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_norm2_f16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_norm2_f16`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_norm2_f16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Norm2 reduction along one axis, bf16 (f32-detour functor + sqrt).
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_norm2_bf16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_norm2_bf16`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_norm2_bf16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Norm2 reduction along one axis, f64.
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_norm2_f64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_norm2_f64`.
///
/// Receives the problem description plus `x` / `y` (both `*const`
/// here), with no workspace or stream. Returns a status code from
/// the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Uphold the same pointer-validity contract as the matching `_run`
/// call; this declaration does not show which pointers are read.
pub fn baracuda_kernels_reduce_norm2_f64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// LogSumExp reduction along one axis, f32 — numerically stable
/// two-pass max-then-sum-exp. Shares the simple-reduce parameter
/// shape so the Rust dispatcher can reach it through the same FFI
/// signature; the kernel internally performs two passes over the
/// reduce axis.
///
/// Returns a status code from the crate-level "Status codes" list
/// (`0` = success).
///
/// NOTE(review): `output_shape`, `stride_x`, and `stride_y` are
/// presumably host arrays (likely `rank` entries each) — confirm.
///
/// # Safety
/// Crate-level contract: pointer arguments are used without bounds
/// checks and must be valid; `workspace`, when non-null, must cover
/// `workspace_bytes` of writable device memory; `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_reduce_logsumexp_f32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_logsumexp_f32`. Host-side only.
///
/// Same parameters as `baracuda_kernels_reduce_logsumexp_f32_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_reduce_logsumexp_f32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// LogSumExp reduction along one axis, f16 (f32-detour throughout).
///
/// Parameter list is identical to `baracuda_kernels_reduce_logsumexp_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_logsumexp_f16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_logsumexp_f16`. Host-side only.
///
/// Same parameters as `baracuda_kernels_reduce_logsumexp_f16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_reduce_logsumexp_f16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// LogSumExp reduction along one axis, bf16 (f32-detour throughout).
///
/// Parameter list is identical to `baracuda_kernels_reduce_logsumexp_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_logsumexp_bf16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_logsumexp_bf16`. Host-side only.
///
/// Same parameters as `baracuda_kernels_reduce_logsumexp_bf16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_reduce_logsumexp_bf16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// LogSumExp reduction along one axis, f64.
///
/// Parameter list is identical to `baracuda_kernels_reduce_logsumexp_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_logsumexp_f64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_logsumexp_f64`. Host-side only.
///
/// Same parameters as `baracuda_kernels_reduce_logsumexp_f64_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_reduce_logsumexp_f64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ============================================================================
// Scans (Category F) — length-preserving prefix operators along a
// single axis. ABI mirrors reduce-axis but adds a `reverse` flag and
// uses the full input shape (no axis collapse in the output).
// ============================================================================
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Inclusive prefix sum (`cumsum`) along a single axis, f32.
/// `reverse != 0` flips the scan direction.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cumsum_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cumsum_f32`.
///
/// Same parameters as `baracuda_kernels_scan_cumsum_f32_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cumsum_f32_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
scan_axis: i32, scan_extent: i32, scan_stride_x: i64, reverse: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// Cumsum, f16. f32-detour accumulator inside the kernel.
///
/// Parameter list is identical to `baracuda_kernels_scan_cumsum_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cumsum_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cumsum_f16`.
///
/// Same parameters as `baracuda_kernels_scan_cumsum_f16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cumsum_f16_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
scan_axis: i32, scan_extent: i32, scan_stride_x: i64, reverse: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// Cumsum, bf16.
///
/// Parameter list is identical to `baracuda_kernels_scan_cumsum_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cumsum_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cumsum_bf16`.
///
/// Same parameters as `baracuda_kernels_scan_cumsum_bf16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cumsum_bf16_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
scan_axis: i32, scan_extent: i32, scan_stride_x: i64, reverse: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// Cumsum, f64.
///
/// Parameter list is identical to `baracuda_kernels_scan_cumsum_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cumsum_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cumsum_f64`.
///
/// Same parameters as `baracuda_kernels_scan_cumsum_f64_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cumsum_f64_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
scan_axis: i32, scan_extent: i32, scan_stride_x: i64, reverse: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// Cumprod (inclusive prefix product), f32. Same ABI as cumsum.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cumprod_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cumprod_f32`.
///
/// Same parameters as `baracuda_kernels_scan_cumprod_f32_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cumprod_f32_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
scan_axis: i32, scan_extent: i32, scan_stride_x: i64, reverse: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// Cumprod, f16. f32-detour accumulator.
///
/// Parameter list is identical to `baracuda_kernels_scan_cumprod_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cumprod_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cumprod_f16`.
///
/// Same parameters as `baracuda_kernels_scan_cumprod_f16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cumprod_f16_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
scan_axis: i32, scan_extent: i32, scan_stride_x: i64, reverse: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// Cumprod, bf16.
///
/// Parameter list is identical to `baracuda_kernels_scan_cumprod_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cumprod_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cumprod_bf16`.
///
/// Same parameters as `baracuda_kernels_scan_cumprod_bf16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cumprod_bf16_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
scan_axis: i32, scan_extent: i32, scan_stride_x: i64, reverse: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// Cumprod, f64.
///
/// Parameter list is identical to `baracuda_kernels_scan_cumprod_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cumprod_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cumprod_f64`.
///
/// Same parameters as `baracuda_kernels_scan_cumprod_f64_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cumprod_f64_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
scan_axis: i32, scan_extent: i32, scan_stride_x: i64, reverse: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// Cummax (inclusive prefix running max), f32.
///
/// Parameter list matches `baracuda_kernels_scan_cumsum_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cummax_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cummax_f32`.
///
/// Same parameters as `baracuda_kernels_scan_cummax_f32_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cummax_f32_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
scan_axis: i32, scan_extent: i32, scan_stride_x: i64, reverse: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// Cummax, f16.
///
/// Parameter list is identical to `baracuda_kernels_scan_cummax_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cummax_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cummax_f16`.
///
/// Same parameters as `baracuda_kernels_scan_cummax_f16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cummax_f16_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
scan_axis: i32, scan_extent: i32, scan_stride_x: i64, reverse: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// Cummax, bf16.
///
/// Parameter list is identical to `baracuda_kernels_scan_cummax_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cummax_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cummax_bf16`.
///
/// Same parameters as `baracuda_kernels_scan_cummax_bf16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cummax_bf16_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
scan_axis: i32, scan_extent: i32, scan_stride_x: i64, reverse: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// Cummax, f64.
///
/// Parameter list is identical to `baracuda_kernels_scan_cummax_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cummax_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cummax_f64`.
///
/// Same parameters as `baracuda_kernels_scan_cummax_f64_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cummax_f64_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
scan_axis: i32, scan_extent: i32, scan_stride_x: i64, reverse: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// Cummin (inclusive prefix running min), f32.
///
/// Parameter list matches `baracuda_kernels_scan_cumsum_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cummin_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cummin_f32`.
///
/// Same parameters as `baracuda_kernels_scan_cummin_f32_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cummin_f32_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
scan_axis: i32, scan_extent: i32, scan_stride_x: i64, reverse: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// Cummin, f16.
///
/// Parameter list is identical to `baracuda_kernels_scan_cummin_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cummin_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cummin_f16`.
///
/// Same parameters as `baracuda_kernels_scan_cummin_f16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cummin_f16_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
scan_axis: i32, scan_extent: i32, scan_stride_x: i64, reverse: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// Cummin, bf16.
///
/// Parameter list is identical to `baracuda_kernels_scan_cummin_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cummin_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cummin_bf16`.
///
/// Same parameters as `baracuda_kernels_scan_cummin_bf16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cummin_bf16_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
scan_axis: i32, scan_extent: i32, scan_stride_x: i64, reverse: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// Cummin, f64.
///
/// Parameter list is identical to `baracuda_kernels_scan_cummin_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cummin_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cummin_f64`.
///
/// Same parameters as `baracuda_kernels_scan_cummin_f64_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cummin_f64_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
scan_axis: i32, scan_extent: i32, scan_stride_x: i64, reverse: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// Cumprod backward, f32. Per-cell suffix accumulator of
/// `dy[i] * y[i] / x[j]`. Caller must ensure x has no zeros along
/// the scan axis.
///
/// Takes one stride array per tensor (`dy`, `x`, `y`, `dx`); unlike the
/// forward scans there is no `scan_stride_x` parameter. `dx` is the only
/// `*mut` tensor pointer. Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cumprod_backward_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cumprod_backward_f32`.
///
/// Same parameters as `baracuda_kernels_scan_cumprod_backward_f32_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dx` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cumprod_backward_f32_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_dy: *const i64, stride_x: *const i64, stride_y: *const i64, stride_dx: *const i64,
scan_axis: i32, scan_extent: i32, reverse: i32,
dy: *const c_void, x: *const c_void, y: *const c_void, dx: *const c_void,
) -> i32;
/// Cumprod backward, f16. f32-detour accumulator.
///
/// Parameter list is identical to
/// `baracuda_kernels_scan_cumprod_backward_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cumprod_backward_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cumprod_backward_f16`.
///
/// Same parameters as `baracuda_kernels_scan_cumprod_backward_f16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dx` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cumprod_backward_f16_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_dy: *const i64, stride_x: *const i64, stride_y: *const i64, stride_dx: *const i64,
scan_axis: i32, scan_extent: i32, reverse: i32,
dy: *const c_void, x: *const c_void, y: *const c_void, dx: *const c_void,
) -> i32;
/// Cumprod backward, bf16.
///
/// Parameter list is identical to
/// `baracuda_kernels_scan_cumprod_backward_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cumprod_backward_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cumprod_backward_bf16`.
///
/// Same parameters as `baracuda_kernels_scan_cumprod_backward_bf16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dx` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cumprod_backward_bf16_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_dy: *const i64, stride_x: *const i64, stride_y: *const i64, stride_dx: *const i64,
scan_axis: i32, scan_extent: i32, reverse: i32,
dy: *const c_void, x: *const c_void, y: *const c_void, dx: *const c_void,
) -> i32;
/// Cumprod backward, f64.
///
/// Parameter list is identical to
/// `baracuda_kernels_scan_cumprod_backward_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cumprod_backward_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cumprod_backward_f64`.
///
/// Same parameters as `baracuda_kernels_scan_cumprod_backward_f64_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dx` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cumprod_backward_f64_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_dy: *const i64, stride_x: *const i64, stride_y: *const i64, stride_dx: *const i64,
scan_axis: i32, scan_extent: i32, reverse: i32,
dy: *const c_void, x: *const c_void, y: *const c_void, dx: *const c_void,
) -> i32;
/// Cummax backward, f32. Walks the forward scan tracking
/// first-occurrence argmax; gradient flows to the source position.
///
/// Takes `dy`, `x`, and `dx` only — no saved forward output `y` and no
/// `scan_stride_x`. Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cummax_backward_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cummax_backward_f32`.
///
/// Same parameters as `baracuda_kernels_scan_cummax_backward_f32_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dx` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cummax_backward_f32_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_dy: *const i64, stride_x: *const i64, stride_dx: *const i64,
scan_axis: i32, scan_extent: i32, reverse: i32,
dy: *const c_void, x: *const c_void, dx: *const c_void,
) -> i32;
/// Cummax backward, f16.
///
/// Parameter list is identical to
/// `baracuda_kernels_scan_cummax_backward_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cummax_backward_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cummax_backward_f16`.
///
/// Same parameters as `baracuda_kernels_scan_cummax_backward_f16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dx` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cummax_backward_f16_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_dy: *const i64, stride_x: *const i64, stride_dx: *const i64,
scan_axis: i32, scan_extent: i32, reverse: i32,
dy: *const c_void, x: *const c_void, dx: *const c_void,
) -> i32;
/// Cummax backward, bf16.
///
/// Parameter list is identical to
/// `baracuda_kernels_scan_cummax_backward_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cummax_backward_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cummax_backward_bf16`.
///
/// Same parameters as `baracuda_kernels_scan_cummax_backward_bf16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dx` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cummax_backward_bf16_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_dy: *const i64, stride_x: *const i64, stride_dx: *const i64,
scan_axis: i32, scan_extent: i32, reverse: i32,
dy: *const c_void, x: *const c_void, dx: *const c_void,
) -> i32;
/// Cummax backward, f64.
///
/// Parameter list is identical to
/// `baracuda_kernels_scan_cummax_backward_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cummax_backward_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cummax_backward_f64`.
///
/// Same parameters as `baracuda_kernels_scan_cummax_backward_f64_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dx` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cummax_backward_f64_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_dy: *const i64, stride_x: *const i64, stride_dx: *const i64,
scan_axis: i32, scan_extent: i32, reverse: i32,
dy: *const c_void, x: *const c_void, dx: *const c_void,
) -> i32;
/// Cummin backward, f32. Same kernel shape as Cummax BW with
/// `<` instead of `>` for the tie-tracking comparison.
///
/// Parameter list matches `baracuda_kernels_scan_cummax_backward_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cummin_backward_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cummin_backward_f32`.
///
/// Same parameters as `baracuda_kernels_scan_cummin_backward_f32_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dx` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cummin_backward_f32_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_dy: *const i64, stride_x: *const i64, stride_dx: *const i64,
scan_axis: i32, scan_extent: i32, reverse: i32,
dy: *const c_void, x: *const c_void, dx: *const c_void,
) -> i32;
/// Cummin backward, f16.
///
/// Parameter list is identical to
/// `baracuda_kernels_scan_cummin_backward_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cummin_backward_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cummin_backward_f16`.
///
/// Same parameters as `baracuda_kernels_scan_cummin_backward_f16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dx` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cummin_backward_f16_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_dy: *const i64, stride_x: *const i64, stride_dx: *const i64,
scan_axis: i32, scan_extent: i32, reverse: i32,
dy: *const c_void, x: *const c_void, dx: *const c_void,
) -> i32;
/// Cummin backward, bf16.
///
/// Parameter list is identical to
/// `baracuda_kernels_scan_cummin_backward_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cummin_backward_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cummin_backward_bf16`.
///
/// Same parameters as `baracuda_kernels_scan_cummin_backward_bf16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dx` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cummin_backward_bf16_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_dy: *const i64, stride_x: *const i64, stride_dx: *const i64,
scan_axis: i32, scan_extent: i32, reverse: i32,
dy: *const c_void, x: *const c_void, dx: *const c_void,
) -> i32;
/// Cummin backward, f64.
///
/// Parameter list is identical to
/// `baracuda_kernels_scan_cummin_backward_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_cummin_backward_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_cummin_backward_f64`.
///
/// Same parameters as `baracuda_kernels_scan_cummin_backward_f64_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dx` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_cummin_backward_f64_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_dy: *const i64, stride_x: *const i64, stride_dx: *const i64,
scan_axis: i32, scan_extent: i32, reverse: i32,
dy: *const c_void, x: *const c_void, dx: *const c_void,
) -> i32;
/// LogCumsumExp FW, f32. `y[k] = log(Σ_{j ≤ k} exp(x[j]))`
/// (or suffix-LSE when `reverse != 0`). Numerically stable via
/// the online running-max algorithm. Same ABI as cumsum.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_log_cumsum_exp_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_log_cumsum_exp_f32`.
///
/// Same parameters as `baracuda_kernels_scan_log_cumsum_exp_f32_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_log_cumsum_exp_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
/// LogCumsumExp FW, f16. f32-detour accumulator inside the kernel.
///
/// Parameter list is identical to
/// `baracuda_kernels_scan_log_cumsum_exp_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_log_cumsum_exp_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_log_cumsum_exp_f16`.
///
/// Same parameters as `baracuda_kernels_scan_log_cumsum_exp_f16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_log_cumsum_exp_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
/// LogCumsumExp FW, bf16.
///
/// Parameter list is identical to
/// `baracuda_kernels_scan_log_cumsum_exp_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_log_cumsum_exp_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_log_cumsum_exp_bf16`.
///
/// Same parameters as `baracuda_kernels_scan_log_cumsum_exp_bf16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_log_cumsum_exp_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
/// LogCumsumExp FW, f64.
///
/// Parameter list is identical to
/// `baracuda_kernels_scan_log_cumsum_exp_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_log_cumsum_exp_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_log_cumsum_exp_f64`.
///
/// Same parameters as `baracuda_kernels_scan_log_cumsum_exp_f64_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_log_cumsum_exp_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
scan_axis: i32,
scan_extent: i32,
scan_stride_x: i64,
reverse: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
/// LogCumsumExp BW, f32. Per-cell accumulator of
/// `Σ dy[i] * exp(x[k] - y[i])` over the FW-direction-dependent
/// `i` range. Needs both saved `x` and saved `y` (same shape since
/// scans are length-preserving). Stable by construction:
/// `x[k] - y[i] ≤ 0` so `exp(.) ∈ [0, 1]`.
///
/// Parameter list matches `baracuda_kernels_scan_cumprod_backward_f32_run`
/// (one stride array per tensor, no `scan_stride_x`). Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_log_cumsum_exp_backward_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_log_cumsum_exp_backward_f32`.
///
/// Same parameters as `baracuda_kernels_scan_log_cumsum_exp_backward_f32_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `dx` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_log_cumsum_exp_backward_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// LogCumsumExp BW, f16. f32-detour accumulator.
///
/// Parameter list is identical to
/// `baracuda_kernels_scan_log_cumsum_exp_backward_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_log_cumsum_exp_backward_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_log_cumsum_exp_backward_f16`.
///
/// Same parameters as `baracuda_kernels_scan_log_cumsum_exp_backward_f16_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `dx` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_log_cumsum_exp_backward_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// LogCumsumExp BW, bf16.
///
/// Parameter list is identical to
/// `baracuda_kernels_scan_log_cumsum_exp_backward_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_log_cumsum_exp_backward_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_log_cumsum_exp_backward_bf16`.
///
/// Same parameters as `baracuda_kernels_scan_log_cumsum_exp_backward_bf16_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `dx` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_log_cumsum_exp_backward_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// LogCumsumExp BW, f64.
///
/// Parameter list is identical to
/// `baracuda_kernels_scan_log_cumsum_exp_backward_f32_run`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_scan_log_cumsum_exp_backward_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `scan_log_cumsum_exp_backward_f64`.
///
/// Same parameters as `baracuda_kernels_scan_log_cumsum_exp_backward_f64_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `dx` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scan_log_cumsum_exp_backward_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
scan_axis: i32,
scan_extent: i32,
reverse: i32,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
}
// ============================================================================
// Softmax family (Category H) — length-preserving stable softmax along
// a single axis. Output shape == input shape. FW + BW per dtype.
// ============================================================================
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Softmax FW, f32. `y[k] = exp(x[k] - max) / Σ exp(x[j] - max)`
/// along `softmax_axis`. Numerically stable.
///
/// Unlike the scan entry points, this takes an axis stride for both the
/// input (`softmax_stride_x`) and the output (`softmax_stride_y`).
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Crate-level FFI contract: pointers are dereferenced unchecked, `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes, and
/// `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_softmax_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `softmax_f32`.
///
/// Same parameters as `baracuda_kernels_softmax_f32_run` minus `workspace`,
/// `workspace_bytes`, and `stream`, with `y` as `*const`.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_softmax_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Softmax FW, f16. f32 accumulator inside the kernel.
///
/// Same argument list as `baracuda_kernels_softmax_f32_run`.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_softmax_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Softmax FW `_can_implement`, f16.
///
/// Takes the arguments of `baracuda_kernels_softmax_f16_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `y` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_softmax_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Softmax FW, bf16.
///
/// Same argument list as `baracuda_kernels_softmax_f32_run`.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_softmax_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Softmax FW `_can_implement`, bf16.
///
/// Takes the arguments of `baracuda_kernels_softmax_bf16_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `y` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_softmax_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Softmax FW, f64.
///
/// Same argument list as `baracuda_kernels_softmax_f32_run`.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_softmax_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Softmax FW `_can_implement`, f64.
///
/// Takes the arguments of `baracuda_kernels_softmax_f64_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `y` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_softmax_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Softmax BW, f32. `dx[k] = y[k] * (dy[k] - Σ_j y[j] * dy[j])`.
/// Caller passes the saved forward output `y`.
///
/// Reads `dy` and `y`, writes `dx`. Only `dy` and `y` get a scalar
/// softmax-axis stride argument (`softmax_stride_dy` /
/// `softmax_stride_y`); there is no `softmax_stride_dx` parameter.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries — confirm against the C header.
///
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list).
///
/// # Safety
///
/// Per the crate-level contract: `dy`, `y` and `dx` must be valid device
/// addresses, `workspace` (when non-null) must point to at least
/// `workspace_bytes` of writable device memory, and `stream` must be a
/// valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
pub fn baracuda_kernels_softmax_backward_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Softmax BW `_can_implement`, f32.
///
/// Takes the arguments of `baracuda_kernels_softmax_backward_f32_run`
/// minus `workspace`, `workspace_bytes` and `stream`; `dx` is `*const`
/// here. Returns an `i32` status (`0` = success; see the crate-level
/// "Status codes" list). Presumably a pre-flight check that launches
/// nothing — TODO confirm.
pub fn baracuda_kernels_softmax_backward_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Softmax BW, f16.
///
/// Same argument list as `baracuda_kernels_softmax_backward_f32_run`
/// (reads `dy` and the saved forward output `y`, writes `dx`).
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_softmax_backward_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Softmax BW `_can_implement`, f16.
///
/// Takes the arguments of `baracuda_kernels_softmax_backward_f16_run`
/// minus `workspace`, `workspace_bytes` and `stream`; `dx` is `*const`
/// here. Returns an `i32` status (`0` = success; see the crate-level
/// "Status codes" list). Presumably a pre-flight check that launches
/// nothing — TODO confirm.
pub fn baracuda_kernels_softmax_backward_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Softmax BW, bf16.
///
/// Same argument list as `baracuda_kernels_softmax_backward_f32_run`
/// (reads `dy` and the saved forward output `y`, writes `dx`).
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_softmax_backward_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Softmax BW `_can_implement`, bf16.
///
/// Takes the arguments of `baracuda_kernels_softmax_backward_bf16_run`
/// minus `workspace`, `workspace_bytes` and `stream`; `dx` is `*const`
/// here. Returns an `i32` status (`0` = success; see the crate-level
/// "Status codes" list). Presumably a pre-flight check that launches
/// nothing — TODO confirm.
pub fn baracuda_kernels_softmax_backward_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Softmax BW, f64.
///
/// Same argument list as `baracuda_kernels_softmax_backward_f32_run`
/// (reads `dy` and the saved forward output `y`, writes `dx`).
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_softmax_backward_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Softmax BW `_can_implement`, f64.
///
/// Takes the arguments of `baracuda_kernels_softmax_backward_f64_run`
/// minus `workspace`, `workspace_bytes` and `stream`; `dx` is `*const`
/// here. Returns an `i32` status (`0` = success; see the crate-level
/// "Status codes" list). Presumably a pre-flight check that launches
/// nothing — TODO confirm.
pub fn baracuda_kernels_softmax_backward_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// LogSoftmax FW, f32. `y[k] = (x[k] - max) - log(Σ exp(x[j] - max))`
/// along `softmax_axis`. Numerically stable.
///
/// Reads `x`, writes `y`. Each tensor has a stride array (`stride_x` /
/// `stride_y`) and a scalar `softmax_stride_*` argument; the argument
/// list matches `baracuda_kernels_softmax_f32_run`.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries, `softmax_extent` presumably equals `shape[softmax_axis]`, and
/// `softmax_stride_*` is presumably the stride along the softmax axis —
/// confirm against the C header.
///
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list).
///
/// # Safety
///
/// Per the crate-level contract: `x` and `y` must be valid device
/// addresses, `workspace` (when non-null) must point to at least
/// `workspace_bytes` of writable device memory, and `stream` must be a
/// valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
pub fn baracuda_kernels_log_softmax_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// LogSoftmax FW `_can_implement`, f32.
///
/// Takes the arguments of `baracuda_kernels_log_softmax_f32_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `y` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_log_softmax_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// LogSoftmax FW, f16. f32 accumulator inside the kernel.
///
/// Same argument list as `baracuda_kernels_log_softmax_f32_run`.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_log_softmax_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// LogSoftmax FW `_can_implement`, f16.
///
/// Takes the arguments of `baracuda_kernels_log_softmax_f16_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `y` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_log_softmax_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// LogSoftmax FW, bf16.
///
/// Same argument list as `baracuda_kernels_log_softmax_f32_run`.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_log_softmax_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// LogSoftmax FW `_can_implement`, bf16.
///
/// Takes the arguments of `baracuda_kernels_log_softmax_bf16_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `y` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_log_softmax_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// LogSoftmax FW, f64.
///
/// Same argument list as `baracuda_kernels_log_softmax_f32_run`.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_log_softmax_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// LogSoftmax FW `_can_implement`, f64.
///
/// Takes the arguments of `baracuda_kernels_log_softmax_f64_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `y` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_log_softmax_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// LogSoftmax BW, f32. `dx[k] = dy[k] - exp(y[k]) * Σ_j dy[j]`.
/// Caller passes the saved forward output `y` (log-softmax values).
///
/// Reads `dy` and `y`, writes `dx`. Only `dy` and `y` get a scalar
/// softmax-axis stride argument (`softmax_stride_dy` /
/// `softmax_stride_y`); there is no `softmax_stride_dx` parameter.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries — confirm against the C header.
///
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list).
///
/// # Safety
///
/// Per the crate-level contract: `dy`, `y` and `dx` must be valid device
/// addresses, `workspace` (when non-null) must point to at least
/// `workspace_bytes` of writable device memory, and `stream` must be a
/// valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
pub fn baracuda_kernels_log_softmax_backward_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// LogSoftmax BW `_can_implement`, f32.
///
/// Takes the arguments of
/// `baracuda_kernels_log_softmax_backward_f32_run` minus `workspace`,
/// `workspace_bytes` and `stream`; `dx` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_log_softmax_backward_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// LogSoftmax BW, f16.
///
/// Same argument list as
/// `baracuda_kernels_log_softmax_backward_f32_run` (reads `dy` and the
/// saved forward output `y`, writes `dx`).
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_log_softmax_backward_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// LogSoftmax BW `_can_implement`, f16.
///
/// Takes the arguments of
/// `baracuda_kernels_log_softmax_backward_f16_run` minus `workspace`,
/// `workspace_bytes` and `stream`; `dx` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_log_softmax_backward_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// LogSoftmax BW, bf16.
///
/// Same argument list as
/// `baracuda_kernels_log_softmax_backward_f32_run` (reads `dy` and the
/// saved forward output `y`, writes `dx`).
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_log_softmax_backward_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// LogSoftmax BW `_can_implement`, bf16.
///
/// Takes the arguments of
/// `baracuda_kernels_log_softmax_backward_bf16_run` minus `workspace`,
/// `workspace_bytes` and `stream`; `dx` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_log_softmax_backward_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// LogSoftmax BW, f64.
///
/// Same argument list as
/// `baracuda_kernels_log_softmax_backward_f32_run` (reads `dy` and the
/// saved forward output `y`, writes `dx`).
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_log_softmax_backward_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// LogSoftmax BW `_can_implement`, f64.
///
/// Takes the arguments of
/// `baracuda_kernels_log_softmax_backward_f64_run` minus `workspace`,
/// `workspace_bytes` and `stream`; `dx` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_log_softmax_backward_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// GumbelSoftmax FW, f32. `y = softmax((x + g) / τ)` where
/// `g = -log(-log(u))` and `u` is a caller-supplied cuRAND uniform
/// buffer (one f32 per output cell, dense / contiguous layout).
/// `inv_tau = 1/τ`. `hard != 0` → one-hot at the noisy argmax.
///
/// Reads `x` and `u_rand`, writes `y`. `u_rand` has no stride arguments
/// of its own, consistent with the dense layout stated above.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries and `softmax_extent` presumably equals `shape[softmax_axis]`
/// — confirm against the C header.
///
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list).
///
/// # Safety
///
/// Per the crate-level contract: `x`, `u_rand` and `y` must be valid
/// device addresses, `workspace` (when non-null) must point to at least
/// `workspace_bytes` of writable device memory, and `stream` must be a
/// valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
pub fn baracuda_kernels_gumbel_softmax_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
inv_tau: f32,
hard: i32,
x: *const c_void,
u_rand: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// GumbelSoftmax FW `_can_implement`, f32.
///
/// Takes the arguments of `baracuda_kernels_gumbel_softmax_f32_run`
/// (including `inv_tau`, `hard` and `u_rand`) minus `workspace`,
/// `workspace_bytes` and `stream`; `y` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_gumbel_softmax_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
inv_tau: f32,
hard: i32,
x: *const c_void,
u_rand: *const c_void,
y: *const c_void,
) -> i32;
/// GumbelSoftmax FW, f16.
///
/// Same argument list as `baracuda_kernels_gumbel_softmax_f32_run`;
/// `inv_tau` is passed as `f32` here as well.
/// NOTE(review): the element type of `u_rand` for this dtype is not
/// visible from the signature — confirm against the kernel source.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_gumbel_softmax_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
inv_tau: f32,
hard: i32,
x: *const c_void,
u_rand: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// GumbelSoftmax FW `_can_implement`, f16.
///
/// Takes the arguments of `baracuda_kernels_gumbel_softmax_f16_run`
/// (including `inv_tau`, `hard` and `u_rand`) minus `workspace`,
/// `workspace_bytes` and `stream`; `y` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_gumbel_softmax_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
inv_tau: f32,
hard: i32,
x: *const c_void,
u_rand: *const c_void,
y: *const c_void,
) -> i32;
/// GumbelSoftmax FW, bf16.
///
/// Same argument list as `baracuda_kernels_gumbel_softmax_f32_run`;
/// `inv_tau` is passed as `f32` here as well.
/// NOTE(review): the element type of `u_rand` for this dtype is not
/// visible from the signature — confirm against the kernel source.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_gumbel_softmax_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
inv_tau: f32,
hard: i32,
x: *const c_void,
u_rand: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// GumbelSoftmax FW `_can_implement`, bf16.
///
/// Takes the arguments of `baracuda_kernels_gumbel_softmax_bf16_run`
/// (including `inv_tau`, `hard` and `u_rand`) minus `workspace`,
/// `workspace_bytes` and `stream`; `y` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_gumbel_softmax_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
inv_tau: f32,
hard: i32,
x: *const c_void,
u_rand: *const c_void,
y: *const c_void,
) -> i32;
/// GumbelSoftmax FW, f64.
///
/// Same argument list as `baracuda_kernels_gumbel_softmax_f32_run`; note
/// that `inv_tau` is `f32` even for this f64 entry point.
/// NOTE(review): the element type of `u_rand` for this dtype is not
/// visible from the signature — confirm against the kernel source.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_gumbel_softmax_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
inv_tau: f32,
hard: i32,
x: *const c_void,
u_rand: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// GumbelSoftmax FW `_can_implement`, f64.
///
/// Takes the arguments of `baracuda_kernels_gumbel_softmax_f64_run`
/// (including `inv_tau`, `hard` and `u_rand`) minus `workspace`,
/// `workspace_bytes` and `stream`; `y` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_gumbel_softmax_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
inv_tau: f32,
hard: i32,
x: *const c_void,
u_rand: *const c_void,
y: *const c_void,
) -> i32;
/// Sparsemax FW, f32. `y = ProjSimplex(x)` via threshold τ found
/// after sorting the row descending. Row extent limited to 64.
///
/// Reads `x`, writes `y`. The argument list matches the softmax forward
/// entry points (`baracuda_kernels_softmax_f32_run`), including the
/// `softmax_*` parameter names.
/// NOTE(review): the 64-element limit presumably applies to
/// `softmax_extent`, and `shape` / the stride arrays presumably hold
/// `rank` entries — confirm against the C header.
///
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list).
///
/// # Safety
///
/// Per the crate-level contract: `x` and `y` must be valid device
/// addresses, `workspace` (when non-null) must point to at least
/// `workspace_bytes` of writable device memory, and `stream` must be a
/// valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
pub fn baracuda_kernels_sparsemax_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Sparsemax FW `_can_implement`, f32.
///
/// Takes the arguments of `baracuda_kernels_sparsemax_f32_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `y` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_sparsemax_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Sparsemax FW, f16.
///
/// Same argument list as `baracuda_kernels_sparsemax_f32_run`.
/// NOTE(review): the row-extent limit of 64 documented on the f32 entry
/// point presumably applies here too — confirm.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_sparsemax_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Sparsemax FW `_can_implement`, f16.
///
/// Takes the arguments of `baracuda_kernels_sparsemax_f16_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `y` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_sparsemax_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Sparsemax FW, bf16.
///
/// Same argument list as `baracuda_kernels_sparsemax_f32_run`.
/// NOTE(review): the row-extent limit of 64 documented on the f32 entry
/// point presumably applies here too — confirm.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_sparsemax_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Sparsemax FW `_can_implement`, bf16.
///
/// Takes the arguments of `baracuda_kernels_sparsemax_bf16_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `y` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_sparsemax_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Sparsemax FW, f64.
///
/// Same argument list as `baracuda_kernels_sparsemax_f32_run`.
/// NOTE(review): the row-extent limit of 64 documented on the f32 entry
/// point presumably applies here too — confirm.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_sparsemax_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Sparsemax FW `_can_implement`, f64.
///
/// Takes the arguments of `baracuda_kernels_sparsemax_f64_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `y` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a pre-flight check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_sparsemax_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Sparsemax BW, f32. `dx[i] = dy[i] - sum_dy_active / n_active` for
/// active positions (`y > 0`), `0` elsewhere.
///
/// Reads `dy` and `y`, writes `dx`. Only `dy` and `y` get a scalar
/// softmax-axis stride argument (`softmax_stride_dy` /
/// `softmax_stride_y`); there is no `softmax_stride_dx` parameter.
/// NOTE(review): `y` is presumably the saved sparsemax forward output,
/// and `shape` / the stride arrays presumably hold `rank` entries —
/// confirm against the C header.
///
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list).
///
/// # Safety
///
/// Per the crate-level contract: `dy`, `y` and `dx` must be valid device
/// addresses, `workspace` (when non-null) must point to at least
/// `workspace_bytes` of writable device memory, and `stream` must be a
/// valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
pub fn baracuda_kernels_sparsemax_backward_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Sparsemax BW `_can_implement`, f32.
///
/// Takes the arguments of `baracuda_kernels_sparsemax_backward_f32_run`
/// minus `workspace`, `workspace_bytes` and `stream`; `dx` is `*const`
/// here. Returns an `i32` status (`0` = success; see the crate-level
/// "Status codes" list). Presumably a pre-flight check that launches
/// nothing — TODO confirm.
pub fn baracuda_kernels_sparsemax_backward_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Sparsemax BW, f16.
///
/// Same argument list as `baracuda_kernels_sparsemax_backward_f32_run`
/// (reads `dy` and `y`, writes `dx`).
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_sparsemax_backward_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Sparsemax BW `_can_implement`, f16.
///
/// Takes the arguments of `baracuda_kernels_sparsemax_backward_f16_run`
/// minus `workspace`, `workspace_bytes` and `stream`; `dx` is `*const`
/// here. Returns an `i32` status (`0` = success; see the crate-level
/// "Status codes" list). Presumably a pre-flight check that launches
/// nothing — TODO confirm.
pub fn baracuda_kernels_sparsemax_backward_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Sparsemax BW, bf16.
///
/// Same argument list as `baracuda_kernels_sparsemax_backward_f32_run`
/// (reads `dy` and `y`, writes `dx`).
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_sparsemax_backward_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Sparsemax BW `_can_implement`, bf16.
///
/// Takes the arguments of `baracuda_kernels_sparsemax_backward_bf16_run`
/// minus `workspace`, `workspace_bytes` and `stream`; `dx` is `*const`
/// here. Returns an `i32` status (`0` = success; see the crate-level
/// "Status codes" list). Presumably a pre-flight check that launches
/// nothing — TODO confirm.
pub fn baracuda_kernels_sparsemax_backward_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Sparsemax BW, f64.
///
/// Same argument list as `baracuda_kernels_sparsemax_backward_f32_run`
/// (reads `dy` and `y`, writes `dx`).
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_sparsemax_backward_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Sparsemax BW `_can_implement`, f64.
///
/// Takes the arguments of `baracuda_kernels_sparsemax_backward_f64_run`
/// minus `workspace`, `workspace_bytes` and `stream`; `dx` is `*const`
/// here. Returns an `i32` status (`0` = success; see the crate-level
/// "Status codes" list). Presumably a pre-flight check that launches
/// nothing — TODO confirm.
pub fn baracuda_kernels_sparsemax_backward_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
}
// ============================================================================
// Loss family (Category R) — MSE / NLL / CrossEntropy / BCE / KLDiv
// (FW + BW × 4 FP dtypes). Per-cell or per-row kernel emits to a
// workspace buffer; a single-block deterministic tree reduction collapses
// to the final scalar for Mean / Sum modes. For None mode the per-cell
// kernel writes directly to out (no reduction). Reduction modes: 0=None,
// 1=Mean, 2=Sum.
// ============================================================================
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// MSE FW, f32. `(pred - target)²` per-cell; mean/sum reduce to scalar.
/// Workspace: `numel * sizeof(T)` bytes for Mean/Sum; unused for None.
///
/// `reduction_mode`: `0` = None, `1` = Mean, `2` = Sum (per the Loss
/// family section header). Reads `pred` and `target`, writes `out`.
/// NOTE(review): `out` presumably holds `numel` elements for None and a
/// single scalar for Mean/Sum — confirm against the C header.
///
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list).
///
/// # Safety
///
/// Per the crate-level contract: `pred`, `target` and `out` must be valid
/// device addresses, `workspace` (when non-null) must point to at least
/// `workspace_bytes` of writable device memory, and `stream` must be a
/// valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
pub fn baracuda_kernels_loss_mse_f32_run(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// MSE FW `_can_implement`, f32. Host-side validator (no launch).
///
/// Takes the arguments of `baracuda_kernels_loss_mse_f32_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `out` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list).
pub fn baracuda_kernels_loss_mse_f32_can_implement(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// MSE FW, f16.
///
/// Same argument list as `baracuda_kernels_loss_mse_f32_run`;
/// `reduction_mode` is `0` = None, `1` = Mean, `2` = Sum.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_loss_mse_f16_run(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// MSE FW `_can_implement`, f16.
///
/// Takes the arguments of `baracuda_kernels_loss_mse_f16_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `out` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a host-side check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_loss_mse_f16_can_implement(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// MSE FW, bf16.
///
/// Same argument list as `baracuda_kernels_loss_mse_f32_run`;
/// `reduction_mode` is `0` = None, `1` = Mean, `2` = Sum.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_loss_mse_bf16_run(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// MSE FW `_can_implement`, bf16.
///
/// Takes the arguments of `baracuda_kernels_loss_mse_bf16_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `out` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a host-side check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_loss_mse_bf16_can_implement(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// MSE FW, f64.
///
/// Same argument list as `baracuda_kernels_loss_mse_f32_run`;
/// `reduction_mode` is `0` = None, `1` = Mean, `2` = Sum.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_loss_mse_f64_run(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// MSE FW `_can_implement`, f64.
///
/// Takes the arguments of `baracuda_kernels_loss_mse_f64_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `out` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a host-side check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_loss_mse_f64_can_implement(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// MSE BW, f32. `dpred = 2·(pred - target) · scale`.
///
/// `reduction_mode`: `0` = None, `1` = Mean, `2` = Sum (per the Loss
/// family section header). Reads `pred`, `target` and `dy`, writes
/// `dpred`.
/// NOTE(review): how `scale` in the formula is derived from
/// `scale_scalar`, `dy` and `reduction_mode` is not visible here —
/// confirm against the kernel source.
///
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list).
///
/// # Safety
///
/// Per the crate-level contract: `pred`, `target`, `dy` and `dpred` must
/// be valid device addresses, `workspace` (when non-null) must point to
/// at least `workspace_bytes` of writable device memory, and `stream`
/// must be a valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
pub fn baracuda_kernels_loss_mse_backward_f32_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// MSE BW `_can_implement`, f32.
///
/// Takes the arguments of `baracuda_kernels_loss_mse_backward_f32_run`
/// minus `workspace`, `workspace_bytes` and `stream`; `dpred` is `*const`
/// here. Returns an `i32` status (`0` = success; see the crate-level
/// "Status codes" list). Presumably a host-side check that launches
/// nothing — TODO confirm.
pub fn baracuda_kernels_loss_mse_backward_f32_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// MSE BW, f16.
///
/// Same argument list as `baracuda_kernels_loss_mse_backward_f32_run`;
/// `scale_scalar` is passed as `f32` here as well.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_loss_mse_backward_f16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// MSE BW `_can_implement`, f16.
///
/// Takes the arguments of `baracuda_kernels_loss_mse_backward_f16_run`
/// minus `workspace`, `workspace_bytes` and `stream`; `dpred` is `*const`
/// here. Returns an `i32` status (`0` = success; see the crate-level
/// "Status codes" list). Presumably a host-side check that launches
/// nothing — TODO confirm.
pub fn baracuda_kernels_loss_mse_backward_f16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// MSE BW, bf16.
///
/// Same argument list as `baracuda_kernels_loss_mse_backward_f32_run`;
/// `scale_scalar` is passed as `f32` here as well.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_loss_mse_backward_bf16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// MSE BW `_can_implement`, bf16.
///
/// Takes the arguments of `baracuda_kernels_loss_mse_backward_bf16_run`
/// minus `workspace`, `workspace_bytes` and `stream`; `dpred` is `*const`
/// here. Returns an `i32` status (`0` = success; see the crate-level
/// "Status codes" list). Presumably a host-side check that launches
/// nothing — TODO confirm.
pub fn baracuda_kernels_loss_mse_backward_bf16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// MSE BW, f64.
///
/// Same argument list as `baracuda_kernels_loss_mse_backward_f32_run`;
/// note that `scale_scalar` is `f32` even for this f64 entry point.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_loss_mse_backward_f64_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// MSE BW `_can_implement`, f64.
///
/// Takes the arguments of `baracuda_kernels_loss_mse_backward_f64_run`
/// minus `workspace`, `workspace_bytes` and `stream`; `dpred` is `*const`
/// here. Returns an `i32` status (`0` = success; see the crate-level
/// "Status codes" list). Presumably a host-side check that launches
/// nothing — TODO confirm.
pub fn baracuda_kernels_loss_mse_backward_f64_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// BCE FW, f32. `-(t·log(p) + (1-t)·log(1-p))` per-cell, then reduce.
/// Caller ensures pred ∈ (0, 1).
///
/// `reduction_mode`: `0` = None, `1` = Mean, `2` = Sum (per the Loss
/// family section header). Reads `pred` and `target`, writes `out`.
/// NOTE(review): workspace sizing presumably matches the MSE forward
/// entry points (`numel * sizeof(T)` bytes for Mean/Sum) — confirm.
///
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list).
///
/// # Safety
///
/// Per the crate-level contract: `pred`, `target` and `out` must be valid
/// device addresses, `workspace` (when non-null) must point to at least
/// `workspace_bytes` of writable device memory, and `stream` must be a
/// valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
pub fn baracuda_kernels_loss_bce_f32_run(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// BCE FW `_can_implement`, f32.
///
/// Takes the arguments of `baracuda_kernels_loss_bce_f32_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `out` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a host-side check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_loss_bce_f32_can_implement(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// BCE FW, f16.
///
/// Same argument list as `baracuda_kernels_loss_bce_f32_run`;
/// `reduction_mode` is `0` = None, `1` = Mean, `2` = Sum.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_loss_bce_f16_run(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// BCE FW `_can_implement`, f16.
///
/// Takes the arguments of `baracuda_kernels_loss_bce_f16_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `out` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a host-side check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_loss_bce_f16_can_implement(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// BCE FW, bf16.
///
/// Same argument list as `baracuda_kernels_loss_bce_f32_run`;
/// `reduction_mode` is `0` = None, `1` = Mean, `2` = Sum.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_loss_bce_bf16_run(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// BCE FW `_can_implement`, bf16.
///
/// Takes the arguments of `baracuda_kernels_loss_bce_bf16_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `out` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a host-side check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_loss_bce_bf16_can_implement(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// BCE FW, f64.
///
/// Same argument list as `baracuda_kernels_loss_bce_f32_run`;
/// `reduction_mode` is `0` = None, `1` = Mean, `2` = Sum.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_loss_bce_f64_run(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// BCE FW `_can_implement`, f64.
///
/// Takes the arguments of `baracuda_kernels_loss_bce_f64_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `out` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a host-side check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_loss_bce_f64_can_implement(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// BCE BW, f32. `dpred = (pred - target) / (pred·(1-pred)) · scale`.
///
/// `reduction_mode`: `0` = None, `1` = Mean, `2` = Sum (per the Loss
/// family section header). Reads `pred`, `target` and `dy`, writes
/// `dpred`.
/// NOTE(review): how `scale` in the formula is derived from
/// `scale_scalar`, `dy` and `reduction_mode` is not visible here, and the
/// forward doc's `pred ∈ (0, 1)` precondition presumably applies too
/// (the formula divides by `pred·(1-pred)`) — confirm.
///
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list).
///
/// # Safety
///
/// Per the crate-level contract: `pred`, `target`, `dy` and `dpred` must
/// be valid device addresses, `workspace` (when non-null) must point to
/// at least `workspace_bytes` of writable device memory, and `stream`
/// must be a valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
pub fn baracuda_kernels_loss_bce_backward_f32_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// BCE BW `_can_implement`, f32.
///
/// Takes the arguments of `baracuda_kernels_loss_bce_backward_f32_run`
/// minus `workspace`, `workspace_bytes` and `stream`; `dpred` is `*const`
/// here. Returns an `i32` status (`0` = success; see the crate-level
/// "Status codes" list). Presumably a host-side check that launches
/// nothing — TODO confirm.
pub fn baracuda_kernels_loss_bce_backward_f32_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// BCE BW, f16.
///
/// Same argument list as `baracuda_kernels_loss_bce_backward_f32_run`;
/// `scale_scalar` is passed as `f32` here as well.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_loss_bce_backward_f16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// BCE BW `_can_implement`, f16.
///
/// Takes the arguments of `baracuda_kernels_loss_bce_backward_f16_run`
/// minus `workspace`, `workspace_bytes` and `stream`; `dpred` is `*const`
/// here. Returns an `i32` status (`0` = success; see the crate-level
/// "Status codes" list). Presumably a host-side check that launches
/// nothing — TODO confirm.
pub fn baracuda_kernels_loss_bce_backward_f16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// BCE BW, bf16.
///
/// Same argument list as `baracuda_kernels_loss_bce_backward_f32_run`;
/// `scale_scalar` is passed as `f32` here as well.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_loss_bce_backward_bf16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// BCE BW `_can_implement`, bf16.
///
/// Takes the arguments of `baracuda_kernels_loss_bce_backward_bf16_run`
/// minus `workspace`, `workspace_bytes` and `stream`; `dpred` is `*const`
/// here. Returns an `i32` status (`0` = success; see the crate-level
/// "Status codes" list). Presumably a host-side check that launches
/// nothing — TODO confirm.
pub fn baracuda_kernels_loss_bce_backward_bf16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// BCE BW, f64.
///
/// Same argument list as `baracuda_kernels_loss_bce_backward_f32_run`;
/// note that `scale_scalar` is `f32` even for this f64 entry point.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_loss_bce_backward_f64_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// BCE BW `_can_implement`, f64.
///
/// Takes the arguments of `baracuda_kernels_loss_bce_backward_f64_run`
/// minus `workspace`, `workspace_bytes` and `stream`; `dpred` is `*const`
/// here. Returns an `i32` status (`0` = success; see the crate-level
/// "Status codes" list). Presumably a host-side check that launches
/// nothing — TODO confirm.
pub fn baracuda_kernels_loss_bce_backward_f64_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// KLDiv FW, f32. `target·(log(target) - input)` per-cell. PyTorch
/// convention: input is already log-prob.
///
/// The `input` of the formula is passed through the `pred` parameter
/// (this signature has no parameter named `input`).
/// `reduction_mode`: `0` = None, `1` = Mean, `2` = Sum (per the Loss
/// family section header). Reads `pred` and `target`, writes `out`.
/// NOTE(review): workspace sizing presumably matches the MSE forward
/// entry points (`numel * sizeof(T)` bytes for Mean/Sum) — confirm.
///
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list).
///
/// # Safety
///
/// Per the crate-level contract: `pred`, `target` and `out` must be valid
/// device addresses, `workspace` (when non-null) must point to at least
/// `workspace_bytes` of writable device memory, and `stream` must be a
/// valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
pub fn baracuda_kernels_loss_kl_div_f32_run(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// KLDiv FW `_can_implement`, f32.
///
/// Takes the arguments of `baracuda_kernels_loss_kl_div_f32_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `out` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a host-side check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_loss_kl_div_f32_can_implement(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// KLDiv FW, f16.
///
/// Same argument list as `baracuda_kernels_loss_kl_div_f32_run`;
/// `reduction_mode` is `0` = None, `1` = Mean, `2` = Sum.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_loss_kl_div_f16_run(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// KLDiv FW `_can_implement`, f16.
///
/// Takes the arguments of `baracuda_kernels_loss_kl_div_f16_run` minus
/// `workspace`, `workspace_bytes` and `stream`; `out` is `*const` here.
/// Returns an `i32` status (`0` = success; see the crate-level "Status
/// codes" list). Presumably a host-side check that launches nothing —
/// TODO confirm.
pub fn baracuda_kernels_loss_kl_div_f16_can_implement(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// KLDiv FW, bf16.
///
/// Same argument list as `baracuda_kernels_loss_kl_div_f32_run`;
/// `reduction_mode` is `0` = None, `1` = Mean, `2` = Sum.
/// Returns an `i32` status (`0` = success); pointer, workspace and stream
/// requirements follow the crate-level docs.
pub fn baracuda_kernels_loss_kl_div_bf16_run(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// KLDiv FW `_can_implement`, bf16.
///
/// Companion to [`baracuda_kernels_loss_kl_div_bf16_run`]: same arguments
/// without `workspace`, `workspace_bytes`, and `stream`, and with `out`
/// declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_kl_div_bf16_can_implement(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// KLDiv FW, f64.
///
/// Takes the same parameter list as [`baracuda_kernels_loss_kl_div_f32_run`],
/// which documents the per-cell formula.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_kl_div_f64_run(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// KLDiv FW `_can_implement`, f64.
///
/// Companion to [`baracuda_kernels_loss_kl_div_f64_run`]: same arguments
/// without `workspace`, `workspace_bytes`, and `stream`, and with `out`
/// declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_kl_div_f64_can_implement(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// KLDiv BW, f32. `dinput = -target · scale`.
///
/// The signature takes no forward-input pointer: only `target`, `dy`, and
/// the `dinput` destination. NOTE(review): how `scale` in the formula is
/// derived from `scale_scalar`, `dy`, and `reduction_mode` is not documented
/// here — confirm.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_kl_div_backward_f32_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
target: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// KLDiv BW `_can_implement`, f32.
///
/// Companion to [`baracuda_kernels_loss_kl_div_backward_f32_run`]: same
/// arguments without `workspace`, `workspace_bytes`, and `stream`, and with
/// `dinput` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_kl_div_backward_f32_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
target: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// KLDiv BW, f16.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_kl_div_backward_f32_run`], which documents the
/// gradient formula.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_kl_div_backward_f16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
target: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// KLDiv BW `_can_implement`, f16.
///
/// Companion to [`baracuda_kernels_loss_kl_div_backward_f16_run`]: same
/// arguments without `workspace`, `workspace_bytes`, and `stream`, and with
/// `dinput` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_kl_div_backward_f16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
target: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// KLDiv BW, bf16.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_kl_div_backward_f32_run`], which documents the
/// gradient formula.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_kl_div_backward_bf16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
target: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// KLDiv BW `_can_implement`, bf16.
///
/// Companion to [`baracuda_kernels_loss_kl_div_backward_bf16_run`]: same
/// arguments without `workspace`, `workspace_bytes`, and `stream`, and with
/// `dinput` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_kl_div_backward_bf16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
target: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// KLDiv BW, f64.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_kl_div_backward_f32_run`], which documents the
/// gradient formula. Note that `scale_scalar` stays `f32` in this variant.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_kl_div_backward_f64_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
target: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// KLDiv BW `_can_implement`, f64.
///
/// Companion to [`baracuda_kernels_loss_kl_div_backward_f64_run`]: same
/// arguments without `workspace`, `workspace_bytes`, and `stream`, and with
/// `dinput` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_kl_div_backward_f64_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
target: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// NLL FW, f32. `-input[i, target[i]]` per row. Heterogeneous-dtype:
/// input is `T`, target is `i64`. `row_stride_input` is the i64 stride
/// between adjacent rows of `input` (must equal `class_extent` for
/// contiguous input).
///
/// Note the mixed index widths in the signature: `n_rows` and
/// `row_stride_input` are `i64`, `class_extent` is `i32`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_nll_f32_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// NLL FW `_can_implement`, f32.
///
/// Companion to [`baracuda_kernels_loss_nll_f32_run`]: same arguments without
/// `workspace`, `workspace_bytes`, and `stream`, and with `out` declared
/// `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_nll_f32_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// NLL FW, f16.
///
/// Takes the same parameter list as [`baracuda_kernels_loss_nll_f32_run`],
/// which documents the per-row formula, `row_stride_input`, and the `i64`
/// `target` dtype.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_nll_f16_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// NLL FW `_can_implement`, f16.
///
/// Companion to [`baracuda_kernels_loss_nll_f16_run`]: same arguments without
/// `workspace`, `workspace_bytes`, and `stream`, and with `out` declared
/// `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_nll_f16_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// NLL FW, bf16.
///
/// Takes the same parameter list as [`baracuda_kernels_loss_nll_f32_run`],
/// which documents the per-row formula, `row_stride_input`, and the `i64`
/// `target` dtype.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_nll_bf16_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// NLL FW `_can_implement`, bf16.
///
/// Companion to [`baracuda_kernels_loss_nll_bf16_run`]: same arguments
/// without `workspace`, `workspace_bytes`, and `stream`, and with `out`
/// declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_nll_bf16_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// NLL FW, f64.
///
/// Takes the same parameter list as [`baracuda_kernels_loss_nll_f32_run`],
/// which documents the per-row formula, `row_stride_input`, and the `i64`
/// `target` dtype.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_nll_f64_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// NLL FW `_can_implement`, f64.
///
/// Companion to [`baracuda_kernels_loss_nll_f64_run`]: same arguments without
/// `workspace`, `workspace_bytes`, and `stream`, and with `out` declared
/// `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_nll_f64_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// NLL BW, f32. Pre-zeros `dinput` (size `dinput_numel · sizeof(T)`),
/// then writes `dinput[i, target[i]] = -dy_or_scale`.
///
/// Argument order differs from the forward entry points: `dy` precedes
/// `target`, and no forward `input` pointer is taken.
/// NOTE(review): `dy_or_scale` presumably resolves to `dy` or `scale_scalar`
/// depending on `reduction_mode` — confirm against the kernel source.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_nll_backward_f32_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
dinput_numel: i64,
reduction_mode: i32,
scale_scalar: f32,
dy: *const c_void,
target: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// NLL BW `_can_implement`, f32.
///
/// Companion to [`baracuda_kernels_loss_nll_backward_f32_run`]: same
/// arguments without `workspace`, `workspace_bytes`, and `stream`, and with
/// `dinput` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_nll_backward_f32_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
dinput_numel: i64,
reduction_mode: i32,
scale_scalar: f32,
dy: *const c_void,
target: *const c_void,
dinput: *const c_void,
) -> i32;
/// NLL BW, f16.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_nll_backward_f32_run`], which documents the
/// pre-zero + scatter behaviour.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_nll_backward_f16_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
dinput_numel: i64,
reduction_mode: i32,
scale_scalar: f32,
dy: *const c_void,
target: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// NLL BW `_can_implement`, f16.
///
/// Companion to [`baracuda_kernels_loss_nll_backward_f16_run`]: same
/// arguments without `workspace`, `workspace_bytes`, and `stream`, and with
/// `dinput` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_nll_backward_f16_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
dinput_numel: i64,
reduction_mode: i32,
scale_scalar: f32,
dy: *const c_void,
target: *const c_void,
dinput: *const c_void,
) -> i32;
/// NLL BW, bf16.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_nll_backward_f32_run`], which documents the
/// pre-zero + scatter behaviour.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_nll_backward_bf16_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
dinput_numel: i64,
reduction_mode: i32,
scale_scalar: f32,
dy: *const c_void,
target: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// NLL BW `_can_implement`, bf16.
///
/// Companion to [`baracuda_kernels_loss_nll_backward_bf16_run`]: same
/// arguments without `workspace`, `workspace_bytes`, and `stream`, and with
/// `dinput` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_nll_backward_bf16_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
dinput_numel: i64,
reduction_mode: i32,
scale_scalar: f32,
dy: *const c_void,
target: *const c_void,
dinput: *const c_void,
) -> i32;
/// NLL BW, f64.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_nll_backward_f32_run`], which documents the
/// pre-zero + scatter behaviour. Note that `scale_scalar` stays `f32` in this
/// variant.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_nll_backward_f64_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
dinput_numel: i64,
reduction_mode: i32,
scale_scalar: f32,
dy: *const c_void,
target: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// NLL BW `_can_implement`, f64.
///
/// Companion to [`baracuda_kernels_loss_nll_backward_f64_run`]: same
/// arguments without `workspace`, `workspace_bytes`, and `stream`, and with
/// `dinput` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_nll_backward_f64_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
dinput_numel: i64,
reduction_mode: i32,
scale_scalar: f32,
dy: *const c_void,
target: *const c_void,
dinput: *const c_void,
) -> i32;
/// CrossEntropy FW, f32. Fused LogSoftmax + NLL. Numerically stable
/// per-row two-pass max subtraction.
///
/// The parameter list matches the NLL forward entry points
/// (`n_rows` / `class_extent` / `row_stride_input`).
/// NOTE(review): `target` is presumably `i64` class indices as documented for
/// NLL — confirm.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_cross_entropy_f32_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// CrossEntropy FW `_can_implement`, f32.
///
/// Companion to [`baracuda_kernels_loss_cross_entropy_f32_run`]: same
/// arguments without `workspace`, `workspace_bytes`, and `stream`, and with
/// `out` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_cross_entropy_f32_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// CrossEntropy FW, f16.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_cross_entropy_f32_run`], which describes the fused
/// LogSoftmax + NLL computation.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_cross_entropy_f16_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// CrossEntropy FW `_can_implement`, f16.
///
/// Companion to [`baracuda_kernels_loss_cross_entropy_f16_run`]: same
/// arguments without `workspace`, `workspace_bytes`, and `stream`, and with
/// `out` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_cross_entropy_f16_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// CrossEntropy FW, bf16.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_cross_entropy_f32_run`], which describes the fused
/// LogSoftmax + NLL computation.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_cross_entropy_bf16_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// CrossEntropy FW `_can_implement`, bf16.
///
/// Companion to [`baracuda_kernels_loss_cross_entropy_bf16_run`]: same
/// arguments without `workspace`, `workspace_bytes`, and `stream`, and with
/// `out` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_cross_entropy_bf16_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// CrossEntropy FW, f64.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_cross_entropy_f32_run`], which describes the fused
/// LogSoftmax + NLL computation.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_cross_entropy_f64_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// CrossEntropy FW `_can_implement`, f64.
///
/// Companion to [`baracuda_kernels_loss_cross_entropy_f64_run`]: same
/// arguments without `workspace`, `workspace_bytes`, and `stream`, and with
/// `out` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_cross_entropy_f64_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// CrossEntropy BW, f32. `dinput[i, c] = (softmax(input)[i, c] - 1{c==t[i]}) · scale`.
///
/// Unlike the NLL backward entry points this one takes the forward `input`
/// pointer, and has no `dinput_numel` argument.
/// NOTE(review): how `scale` is derived from `scale_scalar`, `dy`, and
/// `reduction_mode` is not documented here — confirm.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_cross_entropy_backward_f32_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
scale_scalar: f32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// CrossEntropy BW `_can_implement`, f32.
///
/// Companion to [`baracuda_kernels_loss_cross_entropy_backward_f32_run`]:
/// same arguments without `workspace`, `workspace_bytes`, and `stream`, and
/// with `dinput` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_cross_entropy_backward_f32_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
scale_scalar: f32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// CrossEntropy BW, f16.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_cross_entropy_backward_f32_run`], which documents
/// the gradient formula.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_cross_entropy_backward_f16_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
scale_scalar: f32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// CrossEntropy BW `_can_implement`, f16.
///
/// Companion to [`baracuda_kernels_loss_cross_entropy_backward_f16_run`]:
/// same arguments without `workspace`, `workspace_bytes`, and `stream`, and
/// with `dinput` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_cross_entropy_backward_f16_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
scale_scalar: f32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// CrossEntropy BW, bf16.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_cross_entropy_backward_f32_run`], which documents
/// the gradient formula.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_cross_entropy_backward_bf16_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
scale_scalar: f32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// CrossEntropy BW `_can_implement`, bf16.
///
/// Companion to [`baracuda_kernels_loss_cross_entropy_backward_bf16_run`]:
/// same arguments without `workspace`, `workspace_bytes`, and `stream`, and
/// with `dinput` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_cross_entropy_backward_bf16_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
scale_scalar: f32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// CrossEntropy BW, f64.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_cross_entropy_backward_f32_run`], which documents
/// the gradient formula. Note that `scale_scalar` stays `f32` in this variant.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_cross_entropy_backward_f64_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
scale_scalar: f32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// CrossEntropy BW `_can_implement`, f64.
///
/// Companion to [`baracuda_kernels_loss_cross_entropy_backward_f64_run`]:
/// same arguments without `workspace`, `workspace_bytes`, and `stream`, and
/// with `dinput` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_cross_entropy_backward_f64_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
reduction_mode: i32,
scale_scalar: f32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
}
// ============================================================================
// Phase 47 — Fused Linear Cross-Entropy (Liger-Kernel algorithm port).
//
// Math/algorithm credit: LinkedIn Liger-Kernel BSD-2-Clause (clean-room CUDA
// re-implementation; no Liger source vendored). 17 new `_run` symbols, each
// paired with a `_can_implement` check: 4 per-row + 4 cast + 4 scalar-finalize
// + 4 inplace-scale (one per dtype: f32 / f16 / bf16 / f64) + 1 count-non-ignore = 17.
// ============================================================================
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// FLCE per-row fused step, f32. Mutates `logits` in place to
/// `grad_logits = (softmax - one_hot) · scale_per_row`; writes
/// per-row `-log_softmax[target]` into `loss_1d` (f32 accumulator).
///
/// Takes no workspace arguments. Returns an `i32` status code: `0` on
/// success, non-zero per the crate-level "Status codes" list.
/// NOTE(review): `n_rows` is `i32` here but `i64` in the cast / finalize
/// entry points of this block — confirm both match the C prototypes.
///
/// # Safety
/// Pointers must be valid device addresses and `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_flce_per_row_f32_run(
n_rows: i32, v: i32, row_stride: i64, target_ignore: i64,
scale_per_row: f32,
logits: *mut c_void, target: *const c_void, loss_1d: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Implementability check for [`baracuda_kernels_loss_flce_per_row_f32_run`].
/// Host-side only. Returns an `i32` status code (`0` on success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_loss_flce_per_row_f32_can_implement(
n_rows: i32,
v: i32,
row_stride: i64,
target_ignore: i64,
scale_per_row: f32,
logits: *const c_void,
target: *const c_void,
loss_1d: *const c_void,
) -> i32;
/// FLCE per-row fused step, f16. Same parameter list as the f32 variant;
/// `logits` is the `*mut` buffer rewritten in place.
///
/// Returns an `i32` status code: `0` on success, non-zero per the
/// crate-level "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses and `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_flce_per_row_f16_run(
n_rows: i32, v: i32, row_stride: i64, target_ignore: i64,
scale_per_row: f32,
logits: *mut c_void, target: *const c_void, loss_1d: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Implementability check for [`baracuda_kernels_loss_flce_per_row_f16_run`].
/// Host-side only. Returns an `i32` status code (`0` on success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_loss_flce_per_row_f16_can_implement(
n_rows: i32,
v: i32,
row_stride: i64,
target_ignore: i64,
scale_per_row: f32,
logits: *const c_void,
target: *const c_void,
loss_1d: *const c_void,
) -> i32;
/// FLCE per-row fused step, bf16. Same parameter list as the f32 variant;
/// `logits` is the `*mut` buffer rewritten in place.
///
/// Returns an `i32` status code: `0` on success, non-zero per the
/// crate-level "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses and `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_flce_per_row_bf16_run(
n_rows: i32, v: i32, row_stride: i64, target_ignore: i64,
scale_per_row: f32,
logits: *mut c_void, target: *const c_void, loss_1d: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Implementability check for [`baracuda_kernels_loss_flce_per_row_bf16_run`].
/// Host-side only. Returns an `i32` status code (`0` on success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_loss_flce_per_row_bf16_can_implement(
n_rows: i32,
v: i32,
row_stride: i64,
target_ignore: i64,
scale_per_row: f32,
logits: *const c_void,
target: *const c_void,
loss_1d: *const c_void,
) -> i32;
/// FLCE per-row fused step, f64. Same parameter list as the f32 variant;
/// `logits` is the `*mut` buffer rewritten in place. `scale_per_row` stays
/// `f32` in this variant.
///
/// Returns an `i32` status code: `0` on success, non-zero per the
/// crate-level "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses and `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_flce_per_row_f64_run(
n_rows: i32, v: i32, row_stride: i64, target_ignore: i64,
scale_per_row: f32,
logits: *mut c_void, target: *const c_void, loss_1d: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Implementability check for [`baracuda_kernels_loss_flce_per_row_f64_run`].
/// Host-side only. Returns an `i32` status code (`0` on success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_loss_flce_per_row_f64_can_implement(
n_rows: i32,
v: i32,
row_stride: i64,
target_ignore: i64,
scale_per_row: f32,
logits: *const c_void,
target: *const c_void,
loss_1d: *const c_void,
) -> i32;
/// FLCE per-row cast (None mode finalizer), f32 → f32.
///
/// Reads the `loss_1d` buffer and writes `out`. Returns an `i32` status
/// code: `0` on success, non-zero per the crate-level "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses and `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_flce_per_row_cast_f32_run(
n_rows: i64, loss_1d: *const c_void, out: *mut c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for [`baracuda_kernels_loss_flce_per_row_cast_f32_run`].
/// Host-side only. Returns an `i32` status code (`0` on success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_loss_flce_per_row_cast_f32_can_implement(
n_rows: i64,
loss_1d: *const c_void,
out: *const c_void,
) -> i32;
/// FLCE per-row cast, f32 → f16.
///
/// Reads the `loss_1d` buffer and writes `out`. Returns an `i32` status
/// code: `0` on success, non-zero per the crate-level "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses and `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_flce_per_row_cast_f16_run(
n_rows: i64, loss_1d: *const c_void, out: *mut c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for [`baracuda_kernels_loss_flce_per_row_cast_f16_run`].
/// Host-side only. Returns an `i32` status code (`0` on success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_loss_flce_per_row_cast_f16_can_implement(
n_rows: i64,
loss_1d: *const c_void,
out: *const c_void,
) -> i32;
/// FLCE per-row cast, f32 → bf16.
///
/// Reads the `loss_1d` buffer and writes `out`. Returns an `i32` status
/// code: `0` on success, non-zero per the crate-level "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses and `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_flce_per_row_cast_bf16_run(
n_rows: i64, loss_1d: *const c_void, out: *mut c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for [`baracuda_kernels_loss_flce_per_row_cast_bf16_run`].
/// Host-side only. Returns an `i32` status code (`0` on success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_loss_flce_per_row_cast_bf16_can_implement(
n_rows: i64,
loss_1d: *const c_void,
out: *const c_void,
) -> i32;
/// FLCE per-row cast, f32 → f64.
///
/// Reads the `loss_1d` buffer and writes `out`. Returns an `i32` status
/// code: `0` on success, non-zero per the crate-level "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses and `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_flce_per_row_cast_f64_run(
n_rows: i64, loss_1d: *const c_void, out: *mut c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for [`baracuda_kernels_loss_flce_per_row_cast_f64_run`].
/// Host-side only. Returns an `i32` status code (`0` on success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_loss_flce_per_row_cast_f64_can_implement(
n_rows: i64,
loss_1d: *const c_void,
out: *const c_void,
) -> i32;
/// FLCE scalar finalize (Mean/Sum), f32 → f32.
///
/// NOTE(review): `denom_inv` is presumably the reciprocal of the mean
/// denominator (and `1.0` for Sum) — confirm against the launcher.
/// Returns an `i32` status code: `0` on success, non-zero per the
/// crate-level "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses and `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_flce_scalar_finalize_f32_run(
n_rows: i64, denom_inv: f32,
loss_1d: *const c_void, out: *mut c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for [`baracuda_kernels_loss_flce_scalar_finalize_f32_run`].
/// Host-side only. Returns an `i32` status code (`0` on success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_loss_flce_scalar_finalize_f32_can_implement(
n_rows: i64,
denom_inv: f32,
loss_1d: *const c_void,
out: *const c_void,
) -> i32;
/// FLCE scalar finalize, f32 → f16.
///
/// Same parameter list as the f32 → f32 variant. Returns an `i32` status
/// code: `0` on success, non-zero per the crate-level "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses and `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_flce_scalar_finalize_f16_run(
n_rows: i64, denom_inv: f32,
loss_1d: *const c_void, out: *mut c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for [`baracuda_kernels_loss_flce_scalar_finalize_f16_run`].
/// Host-side only. Returns an `i32` status code (`0` on success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_loss_flce_scalar_finalize_f16_can_implement(
n_rows: i64,
denom_inv: f32,
loss_1d: *const c_void,
out: *const c_void,
) -> i32;
/// FLCE scalar finalize, f32 → bf16.
///
/// Same parameter list as the f32 → f32 variant. Returns an `i32` status
/// code: `0` on success, non-zero per the crate-level "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses and `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_flce_scalar_finalize_bf16_run(
n_rows: i64, denom_inv: f32,
loss_1d: *const c_void, out: *mut c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for [`baracuda_kernels_loss_flce_scalar_finalize_bf16_run`].
/// Host-side only. Returns an `i32` status code (`0` on success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_loss_flce_scalar_finalize_bf16_can_implement(
n_rows: i64,
denom_inv: f32,
loss_1d: *const c_void,
out: *const c_void,
) -> i32;
/// FLCE scalar finalize, f32 → f64.
///
/// Same parameter list as the f32 → f32 variant. Returns an `i32` status
/// code: `0` on success, non-zero per the crate-level "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses and `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_flce_scalar_finalize_f64_run(
n_rows: i64, denom_inv: f32,
loss_1d: *const c_void, out: *mut c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for [`baracuda_kernels_loss_flce_scalar_finalize_f64_run`].
/// Host-side only. Returns an `i32` status code (`0` on success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_loss_flce_scalar_finalize_f64_can_implement(
n_rows: i64,
denom_inv: f32,
loss_1d: *const c_void,
out: *const c_void,
) -> i32;
/// FLCE in-place scale, f32. Multiplies `buf` in place by `scalar`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the
/// crate-level "Status codes" list.
///
/// # Safety
/// `buf` must be a valid device address and `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_flce_inplace_scale_f32_run(
numel: i64, scalar: f32, buf: *mut c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for [`baracuda_kernels_loss_flce_inplace_scale_f32_run`].
/// Host-side only. Returns an `i32` status code (`0` on success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_loss_flce_inplace_scale_f32_can_implement(
numel: i64,
scalar: f32,
buf: *const c_void,
) -> i32;
/// FLCE in-place scale, f16. `scalar` is passed as `f32`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the
/// crate-level "Status codes" list.
///
/// # Safety
/// `buf` must be a valid device address and `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_flce_inplace_scale_f16_run(
numel: i64, scalar: f32, buf: *mut c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for [`baracuda_kernels_loss_flce_inplace_scale_f16_run`].
/// Host-side only. Returns an `i32` status code (`0` on success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_loss_flce_inplace_scale_f16_can_implement(
numel: i64,
scalar: f32,
buf: *const c_void,
) -> i32;
/// FLCE in-place scale, bf16. `scalar` is passed as `f32`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the
/// crate-level "Status codes" list.
///
/// # Safety
/// `buf` must be a valid device address and `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_flce_inplace_scale_bf16_run(
numel: i64, scalar: f32, buf: *mut c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for [`baracuda_kernels_loss_flce_inplace_scale_bf16_run`].
/// Host-side only. Returns an `i32` status code (`0` on success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_loss_flce_inplace_scale_bf16_can_implement(
numel: i64,
scalar: f32,
buf: *const c_void,
) -> i32;
/// FLCE in-place scale, f64. `scalar` is passed as `f32`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the
/// crate-level "Status codes" list.
///
/// # Safety
/// `buf` must be a valid device address and `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_flce_inplace_scale_f64_run(
numel: i64, scalar: f32, buf: *mut c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for [`baracuda_kernels_loss_flce_inplace_scale_f64_run`].
/// Host-side only. Returns an `i32` status code (`0` on success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_loss_flce_inplace_scale_f64_can_implement(
numel: i64,
scalar: f32,
buf: *const c_void,
) -> i32;
/// FLCE count-non-ignore. Single-block tree reduction; writes the
/// `target[i] != ignore_index` count into `count_out[0]` (i64).
///
/// `bt` is the number of `target` entries scanned — presumably
/// batch × sequence length; confirm with the caller. Returns an `i32`
/// status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses and `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_flce_count_non_ignore_run(
bt: i32, ignore_index: i64,
target: *const c_void, count_out: *mut c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for [`baracuda_kernels_loss_flce_count_non_ignore_run`].
/// Host-side only. Returns an `i32` status code (`0` on success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_loss_flce_count_non_ignore_can_implement(
bt: i32,
ignore_index: i64,
target: *const c_void,
count_out: *const c_void,
) -> i32;
}
// ============================================================================
// Milestone 5.2 — Tier-1 losses (L1 / SmoothL1 / Huber / BCEWithLogits /
// PoissonNLL / GaussianNLL / soft-target CrossEntropy). All follow the same
// per-cell-kernel + tree-reduction-finalizer pattern as the original loss
// family; the SmoothL1 / Huber / PoissonNLL / GaussianNLL launchers thread
// the corresponding scalar parameter (β / δ / log_input_flag / eps) through.
// ============================================================================
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// L1 FW, f32. `y = |pred - target|` per-cell; mean/sum reduce to scalar.
///
/// NOTE(review): the integer encoding of `reduction_mode` is not documented
/// here — confirm against the safe wrapper crate.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_l1_f32_run(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// L1 FW `_can_implement`, f32.
///
/// Companion to [`baracuda_kernels_loss_l1_f32_run`]: same arguments without
/// `workspace`, `workspace_bytes`, and `stream`, and with `out` declared
/// `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_l1_f32_can_implement(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// L1 FW, f16.
///
/// Takes the same parameter list as [`baracuda_kernels_loss_l1_f32_run`],
/// which documents the per-cell formula.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_l1_f16_run(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// L1 FW `_can_implement`, f16.
///
/// Companion to [`baracuda_kernels_loss_l1_f16_run`]: same arguments without
/// `workspace`, `workspace_bytes`, and `stream`, and with `out` declared
/// `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_l1_f16_can_implement(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// L1 FW, bf16.
///
/// Takes the same parameter list as [`baracuda_kernels_loss_l1_f32_run`],
/// which documents the per-cell formula.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_l1_bf16_run(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// L1 FW `_can_implement`, bf16.
///
/// Companion to [`baracuda_kernels_loss_l1_bf16_run`]: same arguments without
/// `workspace`, `workspace_bytes`, and `stream`, and with `out` declared
/// `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_l1_bf16_can_implement(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// L1 FW, f64.
///
/// Takes the same parameter list as [`baracuda_kernels_loss_l1_f32_run`],
/// which documents the per-cell formula.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_l1_f64_run(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// L1 FW `_can_implement`, f64.
///
/// Companion to [`baracuda_kernels_loss_l1_f64_run`]: same arguments without
/// `workspace`, `workspace_bytes`, and `stream`, and with `out` declared
/// `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_l1_f64_can_implement(
numel: i64,
reduction_mode: i32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// L1 BW, f32. `dpred = sign(pred - target) · scale`.
///
/// NOTE(review): how `scale` is derived from `scale_scalar`, `dy`, and
/// `reduction_mode` is not documented here — confirm.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_l1_backward_f32_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// L1 BW `_can_implement`, f32.
///
/// Companion to [`baracuda_kernels_loss_l1_backward_f32_run`]: same arguments
/// without `workspace`, `workspace_bytes`, and `stream`, and with `dpred`
/// declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_l1_backward_f32_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// L1 BW, f16.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_l1_backward_f32_run`], which documents the
/// gradient formula.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_l1_backward_f16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// L1 BW `_can_implement`, f16.
///
/// Companion to [`baracuda_kernels_loss_l1_backward_f16_run`]: same arguments
/// without `workspace`, `workspace_bytes`, and `stream`, and with `dpred`
/// declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_l1_backward_f16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// L1 BW, bf16.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_l1_backward_f32_run`], which documents the
/// gradient formula.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_l1_backward_bf16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// L1 BW `_can_implement`, bf16.
///
/// Companion to [`baracuda_kernels_loss_l1_backward_bf16_run`]: same
/// arguments without `workspace`, `workspace_bytes`, and `stream`, and with
/// `dpred` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_l1_backward_bf16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// L1 BW, f64.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_l1_backward_f32_run`], which documents the
/// gradient formula. Note that `scale_scalar` stays `f32` in this variant.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_l1_backward_f64_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// L1 BW `_can_implement`, f64.
///
/// Companion to [`baracuda_kernels_loss_l1_backward_f64_run`]: same arguments
/// without `workspace`, `workspace_bytes`, and `stream`, and with `dpred`
/// declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_l1_backward_f64_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// SmoothL1 FW, f32. `param = β`.
///
/// Same shape as the L1 forward entry points with one extra scalar, `param`,
/// inserted after `reduction_mode`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_smooth_l1_f32_run(
numel: i64,
reduction_mode: i32,
param: f32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// SmoothL1 FW `_can_implement`, f32.
///
/// Companion to [`baracuda_kernels_loss_smooth_l1_f32_run`]: same arguments
/// without `workspace`, `workspace_bytes`, and `stream`, and with `out`
/// declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_smooth_l1_f32_can_implement(
numel: i64,
reduction_mode: i32,
param: f32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// SmoothL1 FW, f16.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_smooth_l1_f32_run`], where `param` is documented
/// as β.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_smooth_l1_f16_run(
numel: i64,
reduction_mode: i32,
param: f32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// SmoothL1 FW `_can_implement`, f16.
///
/// Companion to [`baracuda_kernels_loss_smooth_l1_f16_run`]: same arguments
/// without `workspace`, `workspace_bytes`, and `stream`, and with `out`
/// declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_smooth_l1_f16_can_implement(
numel: i64,
reduction_mode: i32,
param: f32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// SmoothL1 FW, bf16.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_smooth_l1_f32_run`], where `param` is documented
/// as β.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_smooth_l1_bf16_run(
numel: i64,
reduction_mode: i32,
param: f32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// SmoothL1 FW `_can_implement`, bf16.
///
/// Companion to [`baracuda_kernels_loss_smooth_l1_bf16_run`]: same arguments
/// without `workspace`, `workspace_bytes`, and `stream`, and with `out`
/// declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_smooth_l1_bf16_can_implement(
numel: i64,
reduction_mode: i32,
param: f32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// SmoothL1 FW, f64.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_smooth_l1_f32_run`], where `param` is documented
/// as β. Note that `param` stays `f32` in this variant.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_smooth_l1_f64_run(
numel: i64,
reduction_mode: i32,
param: f32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// SmoothL1 FW `_can_implement`, f64.
///
/// Companion to [`baracuda_kernels_loss_smooth_l1_f64_run`]: same arguments
/// without `workspace`, `workspace_bytes`, and `stream`, and with `out`
/// declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_smooth_l1_f64_can_implement(
numel: i64,
reduction_mode: i32,
param: f32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// SmoothL1 BW, f32.
///
/// Same shape as the L1 backward entry points with one extra scalar,
/// `param`, inserted after `scale_scalar`.
/// NOTE(review): `param` presumably carries β as in the forward entry point —
/// confirm.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_smooth_l1_backward_f32_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
param: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// SmoothL1 BW `_can_implement`, f32.
///
/// Companion to [`baracuda_kernels_loss_smooth_l1_backward_f32_run`]: same
/// arguments without `workspace`, `workspace_bytes`, and `stream`, and with
/// `dpred` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_smooth_l1_backward_f32_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
param: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// SmoothL1 BW, f16.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_smooth_l1_backward_f32_run`].
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_smooth_l1_backward_f16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
param: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// SmoothL1 BW `_can_implement`, f16.
///
/// Companion to [`baracuda_kernels_loss_smooth_l1_backward_f16_run`]: same
/// arguments without `workspace`, `workspace_bytes`, and `stream`, and with
/// `dpred` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_smooth_l1_backward_f16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
param: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// SmoothL1 BW, bf16.
///
/// Takes the same parameter list as
/// [`baracuda_kernels_loss_smooth_l1_backward_f32_run`].
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
///
/// # Safety
/// Pointers must be valid device addresses; `workspace`, when non-null, must
/// span at least `workspace_bytes` writable bytes; `stream` must be a valid
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_loss_smooth_l1_backward_bf16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
param: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// SmoothL1 BW `_can_implement`, bf16.
///
/// Companion to [`baracuda_kernels_loss_smooth_l1_backward_bf16_run`]: same
/// arguments without `workspace`, `workspace_bytes`, and `stream`, and with
/// `dpred` declared `*const`.
///
/// Returns an `i32` status code: `0` on success, non-zero per the crate-level
/// "Status codes" list.
pub fn baracuda_kernels_loss_smooth_l1_backward_bf16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
param: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// SmoothL1 BW, f64.
///
/// `pred`, `target`, and `dy` are `*const`; `dpred` is the only `*mut` data
/// pointer. `workspace`, `workspace_bytes`, and `stream` follow the
/// crate-level contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `scale_scalar` and `param` are declared `f32` even for the
/// f64 variant — confirm this matches the C prototype.
/// NOTE(review): the meaning of `param` and the `reduction_mode` encoding are
/// not visible in this declaration — confirm against the kernel source.
pub fn baracuda_kernels_loss_smooth_l1_backward_f64_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
param: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the SmoothL1 BW f64 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_smooth_l1_backward_f64_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `dpred` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_smooth_l1_backward_f64_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
param: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// Huber FW, f32. `param = δ`.
///
/// `pred` and `target` are `*const`; `out` is the only `*mut` data pointer.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): the `reduction_mode` encoding is not visible in this
/// declaration — confirm against the kernel source.
pub fn baracuda_kernels_loss_huber_f32_run(
numel: i64,
reduction_mode: i32,
param: f32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the Huber FW f32 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_huber_f32_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `out` as `*const`.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_huber_f32_can_implement(
numel: i64,
reduction_mode: i32,
param: f32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// Huber FW, f16.
///
/// `pred` and `target` are `*const`; `out` is the only `*mut` data pointer.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `param` is presumably δ, as documented on the f32 variant,
/// and is passed as `f32` — confirm.
pub fn baracuda_kernels_loss_huber_f16_run(
numel: i64,
reduction_mode: i32,
param: f32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the Huber FW f16 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_huber_f16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `out` as `*const`.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_huber_f16_can_implement(
numel: i64,
reduction_mode: i32,
param: f32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// Huber FW, bf16.
///
/// `pred` and `target` are `*const`; `out` is the only `*mut` data pointer.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `param` is presumably δ, as documented on the f32 variant,
/// and is passed as `f32` — confirm.
pub fn baracuda_kernels_loss_huber_bf16_run(
numel: i64,
reduction_mode: i32,
param: f32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the Huber FW bf16 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_huber_bf16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `out` as `*const`.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_huber_bf16_can_implement(
numel: i64,
reduction_mode: i32,
param: f32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// Huber FW, f64.
///
/// `pred` and `target` are `*const`; `out` is the only `*mut` data pointer.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `param` is presumably δ, as documented on the f32 variant;
/// it is declared `f32` even for the f64 variant — confirm against the C
/// prototype.
pub fn baracuda_kernels_loss_huber_f64_run(
numel: i64,
reduction_mode: i32,
param: f32,
pred: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the Huber FW f64 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_huber_f64_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `out` as `*const`.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_huber_f64_can_implement(
numel: i64,
reduction_mode: i32,
param: f32,
pred: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// Huber BW, f32.
///
/// `pred`, `target`, and `dy` are `*const`; `dpred` is the only `*mut` data
/// pointer. `workspace`, `workspace_bytes`, and `stream` follow the
/// crate-level contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `param` is presumably δ as on the Huber FW f32 entry point;
/// the `reduction_mode` encoding is not visible here — confirm.
pub fn baracuda_kernels_loss_huber_backward_f32_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
param: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the Huber BW f32 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_huber_backward_f32_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `dpred` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_huber_backward_f32_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
param: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// Huber BW, f16.
///
/// `pred`, `target`, and `dy` are `*const`; `dpred` is the only `*mut` data
/// pointer. The scalars `scale_scalar` and `param` are passed as `f32`.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `param` is presumably δ as on the Huber FW f32 entry point —
/// confirm.
pub fn baracuda_kernels_loss_huber_backward_f16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
param: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the Huber BW f16 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_huber_backward_f16_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `dpred` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_huber_backward_f16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
param: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// Huber BW, bf16.
///
/// `pred`, `target`, and `dy` are `*const`; `dpred` is the only `*mut` data
/// pointer. The scalars `scale_scalar` and `param` are passed as `f32`.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `param` is presumably δ as on the Huber FW f32 entry point —
/// confirm.
pub fn baracuda_kernels_loss_huber_backward_bf16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
param: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the Huber BW bf16 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_huber_backward_bf16_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `dpred` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_huber_backward_bf16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
param: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// Huber BW, f64.
///
/// `pred`, `target`, and `dy` are `*const`; `dpred` is the only `*mut` data
/// pointer. `workspace`, `workspace_bytes`, and `stream` follow the
/// crate-level contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `scale_scalar` and `param` are declared `f32` even for the
/// f64 variant — confirm this matches the C prototype.
pub fn baracuda_kernels_loss_huber_backward_f64_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
param: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the Huber BW f64 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_huber_backward_f64_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `dpred` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_huber_backward_f64_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
param: f32,
pred: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// BCEWithLogits FW, f32. Stable BCE for raw logits.
///
/// `logits` and `target` are `*const`; `out` is the only `*mut` data pointer.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): the `reduction_mode` encoding is not visible in this
/// declaration — confirm against the kernel source.
pub fn baracuda_kernels_loss_bce_with_logits_f32_run(
numel: i64,
reduction_mode: i32,
logits: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the BCEWithLogits FW f32 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_bce_with_logits_f32_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `out` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_bce_with_logits_f32_can_implement(
numel: i64,
reduction_mode: i32,
logits: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// BCEWithLogits FW, f16.
///
/// `logits` and `target` are `*const`; `out` is the only `*mut` data pointer.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_loss_bce_with_logits_f16_run(
numel: i64,
reduction_mode: i32,
logits: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the BCEWithLogits FW f16 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_bce_with_logits_f16_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `out` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_bce_with_logits_f16_can_implement(
numel: i64,
reduction_mode: i32,
logits: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// BCEWithLogits FW, bf16.
///
/// `logits` and `target` are `*const`; `out` is the only `*mut` data pointer.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_loss_bce_with_logits_bf16_run(
numel: i64,
reduction_mode: i32,
logits: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the BCEWithLogits FW bf16 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_bce_with_logits_bf16_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `out` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_bce_with_logits_bf16_can_implement(
numel: i64,
reduction_mode: i32,
logits: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// BCEWithLogits FW, f64.
///
/// `logits` and `target` are `*const`; `out` is the only `*mut` data pointer.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_loss_bce_with_logits_f64_run(
numel: i64,
reduction_mode: i32,
logits: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the BCEWithLogits FW f64 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_bce_with_logits_f64_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `out` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_bce_with_logits_f64_can_implement(
numel: i64,
reduction_mode: i32,
logits: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// BCEWithLogits BW, f32. `dlogits = (sigmoid(x) - target) · scale`.
///
/// `logits`, `target`, and `dy` are `*const`; the gradient pointer is named
/// `dpred` in this signature and is the only `*mut` data pointer.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `scale` in the formula presumably corresponds to
/// `scale_scalar` (possibly combined with `dy`) — confirm.
pub fn baracuda_kernels_loss_bce_with_logits_backward_f32_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
logits: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the BCEWithLogits BW f32 kernel.
///
/// Takes the arguments of
/// `baracuda_kernels_loss_bce_with_logits_backward_f32_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dpred` as `*const`.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_bce_with_logits_backward_f32_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
logits: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// BCEWithLogits BW, f16.
///
/// `logits`, `target`, and `dy` are `*const`; the gradient pointer is named
/// `dpred` in this signature and is the only `*mut` data pointer.
/// `scale_scalar` is passed as `f32`. `workspace`, `workspace_bytes`, and
/// `stream` follow the crate-level contract. Returns an `i32` status code
/// (`0` = success).
pub fn baracuda_kernels_loss_bce_with_logits_backward_f16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
logits: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the BCEWithLogits BW f16 kernel.
///
/// Takes the arguments of
/// `baracuda_kernels_loss_bce_with_logits_backward_f16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dpred` as `*const`.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_bce_with_logits_backward_f16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
logits: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// BCEWithLogits BW, bf16.
///
/// `logits`, `target`, and `dy` are `*const`; the gradient pointer is named
/// `dpred` in this signature and is the only `*mut` data pointer.
/// `scale_scalar` is passed as `f32`. `workspace`, `workspace_bytes`, and
/// `stream` follow the crate-level contract. Returns an `i32` status code
/// (`0` = success).
pub fn baracuda_kernels_loss_bce_with_logits_backward_bf16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
logits: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the BCEWithLogits BW bf16 kernel.
///
/// Takes the arguments of
/// `baracuda_kernels_loss_bce_with_logits_backward_bf16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dpred` as `*const`.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_bce_with_logits_backward_bf16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
logits: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// BCEWithLogits BW, f64.
///
/// `logits`, `target`, and `dy` are `*const`; the gradient pointer is named
/// `dpred` in this signature and is the only `*mut` data pointer.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `scale_scalar` is declared `f32` even for the f64 variant —
/// confirm this matches the C prototype.
pub fn baracuda_kernels_loss_bce_with_logits_backward_f64_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
logits: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the BCEWithLogits BW f64 kernel.
///
/// Takes the arguments of
/// `baracuda_kernels_loss_bce_with_logits_backward_f64_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dpred` as `*const`.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_bce_with_logits_backward_f64_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
logits: *const c_void,
target: *const c_void,
dy: *const c_void,
dpred: *const c_void,
) -> i32;
/// PoissonNLL FW, f32. `log_input_flag` 0/1.
///
/// `input` and `target` are `*const`; `out` is the only `*mut` data pointer.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): presumably `log_input_flag = 1` means `input` is already in
/// log space — confirm against the kernel source.
pub fn baracuda_kernels_loss_poisson_nll_f32_run(
numel: i64,
reduction_mode: i32,
log_input_flag: i32,
input: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the PoissonNLL FW f32 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_poisson_nll_f32_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `out` as `*const`.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_poisson_nll_f32_can_implement(
numel: i64,
reduction_mode: i32,
log_input_flag: i32,
input: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// PoissonNLL FW, f16.
///
/// `input` and `target` are `*const`; `out` is the only `*mut` data pointer.
/// `log_input_flag` is an `i32` (documented as 0/1 on the f32 variant).
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_loss_poisson_nll_f16_run(
numel: i64,
reduction_mode: i32,
log_input_flag: i32,
input: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the PoissonNLL FW f16 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_poisson_nll_f16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `out` as `*const`.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_poisson_nll_f16_can_implement(
numel: i64,
reduction_mode: i32,
log_input_flag: i32,
input: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// PoissonNLL FW, bf16.
///
/// `input` and `target` are `*const`; `out` is the only `*mut` data pointer.
/// `log_input_flag` is an `i32` (documented as 0/1 on the f32 variant).
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_loss_poisson_nll_bf16_run(
numel: i64,
reduction_mode: i32,
log_input_flag: i32,
input: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the PoissonNLL FW bf16 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_poisson_nll_bf16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `out` as `*const`.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_poisson_nll_bf16_can_implement(
numel: i64,
reduction_mode: i32,
log_input_flag: i32,
input: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// PoissonNLL FW, f64.
///
/// `input` and `target` are `*const`; `out` is the only `*mut` data pointer.
/// `log_input_flag` is an `i32` (documented as 0/1 on the f32 variant).
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_loss_poisson_nll_f64_run(
numel: i64,
reduction_mode: i32,
log_input_flag: i32,
input: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the PoissonNLL FW f64 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_poisson_nll_f64_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `out` as `*const`.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_poisson_nll_f64_can_implement(
numel: i64,
reduction_mode: i32,
log_input_flag: i32,
input: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// PoissonNLL BW, f32.
///
/// `input`, `target`, and `dy` are `*const`; `dinput` is the only `*mut`
/// data pointer. `workspace`, `workspace_bytes`, and `stream` follow the
/// crate-level contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `log_input_flag` is presumably the same 0/1 flag as on the
/// PoissonNLL FW entry points — confirm.
pub fn baracuda_kernels_loss_poisson_nll_backward_f32_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
log_input_flag: i32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the PoissonNLL BW f32 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_poisson_nll_backward_f32_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `dinput` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_poisson_nll_backward_f32_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
log_input_flag: i32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// PoissonNLL BW, f16.
///
/// `input`, `target`, and `dy` are `*const`; `dinput` is the only `*mut`
/// data pointer. `scale_scalar` is passed as `f32`. `workspace`,
/// `workspace_bytes`, and `stream` follow the crate-level contract.
/// Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_loss_poisson_nll_backward_f16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
log_input_flag: i32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the PoissonNLL BW f16 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_poisson_nll_backward_f16_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `dinput` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_poisson_nll_backward_f16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
log_input_flag: i32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// PoissonNLL BW, bf16.
///
/// `input`, `target`, and `dy` are `*const`; `dinput` is the only `*mut`
/// data pointer. `scale_scalar` is passed as `f32`. `workspace`,
/// `workspace_bytes`, and `stream` follow the crate-level contract.
/// Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_loss_poisson_nll_backward_bf16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
log_input_flag: i32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the PoissonNLL BW bf16 kernel.
///
/// Takes the arguments of
/// `baracuda_kernels_loss_poisson_nll_backward_bf16_run` minus `workspace`,
/// `workspace_bytes`, and `stream`, with `dinput` as `*const`. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes"
/// section).
pub fn baracuda_kernels_loss_poisson_nll_backward_bf16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
log_input_flag: i32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// PoissonNLL BW, f64.
///
/// `input`, `target`, and `dy` are `*const`; `dinput` is the only `*mut`
/// data pointer. `workspace`, `workspace_bytes`, and `stream` follow the
/// crate-level contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `scale_scalar` is declared `f32` even for the f64 variant —
/// confirm this matches the C prototype.
pub fn baracuda_kernels_loss_poisson_nll_backward_f64_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
log_input_flag: i32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the PoissonNLL BW f64 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_poisson_nll_backward_f64_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `dinput` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_poisson_nll_backward_f64_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
log_input_flag: i32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// GaussianNLL FW, f32. 3-tensor input (input, target, var).
///
/// `input`, `target`, and `var` are `*const`; `out` is the only `*mut` data
/// pointer. `workspace`, `workspace_bytes`, and `stream` follow the
/// crate-level contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `eps` is presumably a lower clamp applied to `var` —
/// confirm against the kernel source.
pub fn baracuda_kernels_loss_gaussian_nll_f32_run(
numel: i64,
reduction_mode: i32,
eps: f32,
input: *const c_void,
target: *const c_void,
var: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the GaussianNLL FW f32 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_gaussian_nll_f32_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `out` as `*const`.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_gaussian_nll_f32_can_implement(
numel: i64,
reduction_mode: i32,
eps: f32,
input: *const c_void,
target: *const c_void,
var: *const c_void,
out: *const c_void,
) -> i32;
/// GaussianNLL FW, f16.
///
/// `input`, `target`, and `var` are `*const`; `out` is the only `*mut` data
/// pointer. `eps` is passed as `f32`. `workspace`, `workspace_bytes`, and
/// `stream` follow the crate-level contract. Returns an `i32` status code
/// (`0` = success).
pub fn baracuda_kernels_loss_gaussian_nll_f16_run(
numel: i64,
reduction_mode: i32,
eps: f32,
input: *const c_void,
target: *const c_void,
var: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the GaussianNLL FW f16 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_gaussian_nll_f16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `out` as `*const`.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_gaussian_nll_f16_can_implement(
numel: i64,
reduction_mode: i32,
eps: f32,
input: *const c_void,
target: *const c_void,
var: *const c_void,
out: *const c_void,
) -> i32;
/// GaussianNLL FW, bf16.
///
/// `input`, `target`, and `var` are `*const`; `out` is the only `*mut` data
/// pointer. `eps` is passed as `f32`. `workspace`, `workspace_bytes`, and
/// `stream` follow the crate-level contract. Returns an `i32` status code
/// (`0` = success).
pub fn baracuda_kernels_loss_gaussian_nll_bf16_run(
numel: i64,
reduction_mode: i32,
eps: f32,
input: *const c_void,
target: *const c_void,
var: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the GaussianNLL FW bf16 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_gaussian_nll_bf16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `out` as `*const`.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_gaussian_nll_bf16_can_implement(
numel: i64,
reduction_mode: i32,
eps: f32,
input: *const c_void,
target: *const c_void,
var: *const c_void,
out: *const c_void,
) -> i32;
/// GaussianNLL FW, f64.
///
/// `input`, `target`, and `var` are `*const`; `out` is the only `*mut` data
/// pointer. `workspace`, `workspace_bytes`, and `stream` follow the
/// crate-level contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `eps` is declared `f32` even for the f64 variant — confirm
/// this matches the C prototype.
pub fn baracuda_kernels_loss_gaussian_nll_f64_run(
numel: i64,
reduction_mode: i32,
eps: f32,
input: *const c_void,
target: *const c_void,
var: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the GaussianNLL FW f64 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_gaussian_nll_f64_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `out` as `*const`.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_gaussian_nll_f64_can_implement(
numel: i64,
reduction_mode: i32,
eps: f32,
input: *const c_void,
target: *const c_void,
var: *const c_void,
out: *const c_void,
) -> i32;
/// GaussianNLL BW, f32.
///
/// `input`, `target`, `var`, and `dy` are `*const`; `dinput` is the only
/// `*mut` data pointer (this signature has no gradient pointer for `var`).
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `eps` is presumably a lower clamp applied to `var` —
/// confirm against the kernel source.
pub fn baracuda_kernels_loss_gaussian_nll_backward_f32_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
eps: f32,
input: *const c_void,
target: *const c_void,
var: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the GaussianNLL BW f32 kernel.
///
/// Takes the arguments of
/// `baracuda_kernels_loss_gaussian_nll_backward_f32_run` minus `workspace`,
/// `workspace_bytes`, and `stream`, with `dinput` as `*const`. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes"
/// section).
pub fn baracuda_kernels_loss_gaussian_nll_backward_f32_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
eps: f32,
input: *const c_void,
target: *const c_void,
var: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// GaussianNLL BW, f16.
///
/// `input`, `target`, `var`, and `dy` are `*const`; `dinput` is the only
/// `*mut` data pointer (this signature has no gradient pointer for `var`).
/// `scale_scalar` and `eps` are passed as `f32`. `workspace`,
/// `workspace_bytes`, and `stream` follow the crate-level contract.
/// Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_loss_gaussian_nll_backward_f16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
eps: f32,
input: *const c_void,
target: *const c_void,
var: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the GaussianNLL BW f16 kernel.
///
/// Takes the arguments of
/// `baracuda_kernels_loss_gaussian_nll_backward_f16_run` minus `workspace`,
/// `workspace_bytes`, and `stream`, with `dinput` as `*const`. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes"
/// section).
pub fn baracuda_kernels_loss_gaussian_nll_backward_f16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
eps: f32,
input: *const c_void,
target: *const c_void,
var: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// GaussianNLL BW, bf16.
///
/// `input`, `target`, `var`, and `dy` are `*const`; `dinput` is the only
/// `*mut` data pointer (this signature has no gradient pointer for `var`).
/// `scale_scalar` and `eps` are passed as `f32`. `workspace`,
/// `workspace_bytes`, and `stream` follow the crate-level contract.
/// Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_loss_gaussian_nll_backward_bf16_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
eps: f32,
input: *const c_void,
target: *const c_void,
var: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the GaussianNLL BW bf16 kernel.
///
/// Takes the arguments of
/// `baracuda_kernels_loss_gaussian_nll_backward_bf16_run` minus `workspace`,
/// `workspace_bytes`, and `stream`, with `dinput` as `*const`. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes"
/// section).
pub fn baracuda_kernels_loss_gaussian_nll_backward_bf16_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
eps: f32,
input: *const c_void,
target: *const c_void,
var: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// GaussianNLL BW, f64.
///
/// `input`, `target`, `var`, and `dy` are `*const`; `dinput` is the only
/// `*mut` data pointer (this signature has no gradient pointer for `var`).
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `scale_scalar` and `eps` are declared `f32` even for the
/// f64 variant — confirm this matches the C prototype.
pub fn baracuda_kernels_loss_gaussian_nll_backward_f64_run(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
eps: f32,
input: *const c_void,
target: *const c_void,
var: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the GaussianNLL BW f64 kernel.
///
/// Takes the arguments of
/// `baracuda_kernels_loss_gaussian_nll_backward_f64_run` minus `workspace`,
/// `workspace_bytes`, and `stream`, with `dinput` as `*const`. Returns an
/// `i32` status code (`0` = success; see the crate-level "Status codes"
/// section).
pub fn baracuda_kernels_loss_gaussian_nll_backward_f64_can_implement(
numel: i64,
reduction_mode: i32,
scale_scalar: f32,
eps: f32,
input: *const c_void,
target: *const c_void,
var: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// Soft-target CrossEntropy FW, f32. Target is `T`-typed prob tensor.
///
/// The problem is described by `n_rows` (`i64`), `class_extent` (`i32`), and
/// one row stride each for `input` and `target`. `input` and `target` are
/// `*const`; `out` is the only `*mut` data pointer. `workspace`,
/// `workspace_bytes`, and `stream` follow the crate-level contract.
/// Returns an `i32` status code (`0` = success).
/// NOTE(review): stride units (elements vs. bytes) are not visible in this
/// declaration — confirm against the kernel source.
pub fn baracuda_kernels_loss_cross_entropy_soft_f32_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
row_stride_target: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the soft-target CrossEntropy FW f32 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_cross_entropy_soft_f32_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `out` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_cross_entropy_soft_f32_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
row_stride_target: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// Soft-target CrossEntropy FW, f16.
///
/// The problem is described by `n_rows` (`i64`), `class_extent` (`i32`), and
/// one row stride each for `input` and `target`. `input` and `target` are
/// `*const`; `out` is the only `*mut` data pointer. `workspace`,
/// `workspace_bytes`, and `stream` follow the crate-level contract.
/// Returns an `i32` status code (`0` = success).
/// NOTE(review): stride units (elements vs. bytes) are not visible in this
/// declaration — confirm against the kernel source.
pub fn baracuda_kernels_loss_cross_entropy_soft_f16_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
row_stride_target: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the soft-target CrossEntropy FW f16 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_cross_entropy_soft_f16_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `out` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_cross_entropy_soft_f16_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
row_stride_target: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// Soft-target CrossEntropy FW, bf16.
///
/// The problem is described by `n_rows` (`i64`), `class_extent` (`i32`), and
/// one row stride each for `input` and `target`. `input` and `target` are
/// `*const`; `out` is the only `*mut` data pointer. `workspace`,
/// `workspace_bytes`, and `stream` follow the crate-level contract.
/// Returns an `i32` status code (`0` = success).
/// NOTE(review): stride units (elements vs. bytes) are not visible in this
/// declaration — confirm against the kernel source.
pub fn baracuda_kernels_loss_cross_entropy_soft_bf16_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
row_stride_target: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the soft-target CrossEntropy FW bf16 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_cross_entropy_soft_bf16_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `out` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_cross_entropy_soft_bf16_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
row_stride_target: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// Soft-target CrossEntropy FW, f64.
///
/// The problem is described by `n_rows` (`i64`), `class_extent` (`i32`), and
/// one row stride each for `input` and `target`. `input` and `target` are
/// `*const`; `out` is the only `*mut` data pointer. `workspace`,
/// `workspace_bytes`, and `stream` follow the crate-level contract.
/// Returns an `i32` status code (`0` = success).
/// NOTE(review): stride units (elements vs. bytes) are not visible in this
/// declaration — confirm against the kernel source.
pub fn baracuda_kernels_loss_cross_entropy_soft_f64_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
row_stride_target: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the soft-target CrossEntropy FW f64 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_cross_entropy_soft_f64_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `out` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_cross_entropy_soft_f64_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
row_stride_target: i64,
reduction_mode: i32,
input: *const c_void,
target: *const c_void,
out: *const c_void,
) -> i32;
/// Soft-target CrossEntropy BW, f32.
///
/// Uses the same shape arguments as the forward entry points (`n_rows`,
/// `class_extent`, two row strides) plus `scale_scalar`. `input`, `target`,
/// and `dy` are `*const`; `dinput` is the only `*mut` data pointer.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): no stride argument is declared for `dinput`; presumably it
/// shares `row_stride_input` or is contiguous — confirm against the kernel.
pub fn baracuda_kernels_loss_cross_entropy_soft_backward_f32_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
row_stride_target: i64,
reduction_mode: i32,
scale_scalar: f32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the soft-target CrossEntropy BW f32 kernel.
///
/// Takes the arguments of
/// `baracuda_kernels_loss_cross_entropy_soft_backward_f32_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dinput` as `*const`.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_cross_entropy_soft_backward_f32_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
row_stride_target: i64,
reduction_mode: i32,
scale_scalar: f32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// MarginRanking FW, f32. ABI: `(numel, reduction_mode, margin,
/// x1, x2, t, out, workspace, workspace_bytes, stream)`.
///
/// `x1`, `x2`, and `t` are `*const`; `out` is the only `*mut` data pointer.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_loss_margin_ranking_f32_run(
numel: i64, reduction_mode: i32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query for the MarginRanking FW f32 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_margin_ranking_f32_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `out` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_margin_ranking_f32_can_implement(
numel: i64, reduction_mode: i32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, out: *const c_void,
) -> i32;
/// MarginRanking FW, f16.
///
/// `x1`, `x2`, and `t` are `*const`; `out` is the only `*mut` data pointer.
/// `margin` is passed as `f32`. `workspace`, `workspace_bytes`, and `stream`
/// follow the crate-level contract. Returns an `i32` status code
/// (`0` = success).
pub fn baracuda_kernels_loss_margin_ranking_f16_run(
numel: i64, reduction_mode: i32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query for the MarginRanking FW f16 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_margin_ranking_f16_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `out` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_margin_ranking_f16_can_implement(
numel: i64, reduction_mode: i32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, out: *const c_void,
) -> i32;
/// MarginRanking FW, bf16.
///
/// `x1`, `x2`, and `t` are `*const`; `out` is the only `*mut` data pointer.
/// `margin` is passed as `f32`. `workspace`, `workspace_bytes`, and `stream`
/// follow the crate-level contract. Returns an `i32` status code
/// (`0` = success).
pub fn baracuda_kernels_loss_margin_ranking_bf16_run(
numel: i64, reduction_mode: i32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query for the MarginRanking FW bf16 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_margin_ranking_bf16_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `out` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_margin_ranking_bf16_can_implement(
numel: i64, reduction_mode: i32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, out: *const c_void,
) -> i32;
/// MarginRanking FW, f64.
///
/// `x1`, `x2`, and `t` are `*const`; `out` is the only `*mut` data pointer.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `margin` is declared `f32` even for the f64 variant —
/// confirm this matches the C prototype.
pub fn baracuda_kernels_loss_margin_ranking_f64_run(
numel: i64, reduction_mode: i32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query for the MarginRanking FW f64 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_margin_ranking_f64_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `out` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_margin_ranking_f64_can_implement(
numel: i64, reduction_mode: i32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, out: *const c_void,
) -> i32;
/// MarginRanking BW, f32. ABI: `(numel, reduction_mode, scale_scalar, margin,
/// x1, x2, t, dy, dx1, dx2, workspace, workspace_bytes, stream)`.
///
/// `x1`, `x2`, `t`, and `dy` are `*const`; `dx1` and `dx2` are both `*mut`.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_loss_margin_ranking_backward_f32_run(
numel: i64, reduction_mode: i32, scale_scalar: f32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, dy: *const c_void,
dx1: *mut c_void, dx2: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query for the MarginRanking BW f32 kernel.
///
/// Takes the arguments of
/// `baracuda_kernels_loss_margin_ranking_backward_f32_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dx1` and `dx2` as
/// `*const`. Returns an `i32` status code (`0` = success; see the
/// crate-level "Status codes" section).
pub fn baracuda_kernels_loss_margin_ranking_backward_f32_can_implement(
numel: i64, reduction_mode: i32, scale_scalar: f32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, dy: *const c_void,
dx1: *const c_void, dx2: *const c_void,
) -> i32;
/// MarginRanking BW, f16.
///
/// `x1`, `x2`, `t`, and `dy` are `*const`; `dx1` and `dx2` are both `*mut`.
/// `scale_scalar` and `margin` are passed as `f32`. `workspace`,
/// `workspace_bytes`, and `stream` follow the crate-level contract.
/// Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_loss_margin_ranking_backward_f16_run(
numel: i64, reduction_mode: i32, scale_scalar: f32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, dy: *const c_void,
dx1: *mut c_void, dx2: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query for the MarginRanking BW f16 kernel.
///
/// Takes the arguments of
/// `baracuda_kernels_loss_margin_ranking_backward_f16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dx1` and `dx2` as
/// `*const`. Returns an `i32` status code (`0` = success; see the
/// crate-level "Status codes" section).
pub fn baracuda_kernels_loss_margin_ranking_backward_f16_can_implement(
numel: i64, reduction_mode: i32, scale_scalar: f32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, dy: *const c_void,
dx1: *const c_void, dx2: *const c_void,
) -> i32;
/// MarginRanking BW, bf16.
///
/// `x1`, `x2`, `t`, and `dy` are `*const`; `dx1` and `dx2` are both `*mut`.
/// `scale_scalar` and `margin` are passed as `f32`. `workspace`,
/// `workspace_bytes`, and `stream` follow the crate-level contract.
/// Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_loss_margin_ranking_backward_bf16_run(
numel: i64, reduction_mode: i32, scale_scalar: f32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, dy: *const c_void,
dx1: *mut c_void, dx2: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query for the MarginRanking BW bf16 kernel.
///
/// Takes the arguments of
/// `baracuda_kernels_loss_margin_ranking_backward_bf16_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dx1` and `dx2` as
/// `*const`. Returns an `i32` status code (`0` = success; see the
/// crate-level "Status codes" section).
pub fn baracuda_kernels_loss_margin_ranking_backward_bf16_can_implement(
numel: i64, reduction_mode: i32, scale_scalar: f32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, dy: *const c_void,
dx1: *const c_void, dx2: *const c_void,
) -> i32;
/// MarginRanking BW, f64.
///
/// `x1`, `x2`, `t`, and `dy` are `*const`; `dx1` and `dx2` are both `*mut`.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `scale_scalar` and `margin` are declared `f32` even for the
/// f64 variant — confirm this matches the C prototype.
pub fn baracuda_kernels_loss_margin_ranking_backward_f64_run(
numel: i64, reduction_mode: i32, scale_scalar: f32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, dy: *const c_void,
dx1: *mut c_void, dx2: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query for the MarginRanking BW f64 kernel.
///
/// Takes the arguments of
/// `baracuda_kernels_loss_margin_ranking_backward_f64_run` minus
/// `workspace`, `workspace_bytes`, and `stream`, with `dx1` and `dx2` as
/// `*const`. Returns an `i32` status code (`0` = success; see the
/// crate-level "Status codes" section).
pub fn baracuda_kernels_loss_margin_ranking_backward_f64_can_implement(
numel: i64, reduction_mode: i32, scale_scalar: f32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, dy: *const c_void,
dx1: *const c_void, dx2: *const c_void,
) -> i32;
/// HingeEmbedding FW, f32. ABI: `(numel, reduction_mode, margin,
/// input, target_i64, out, workspace, workspace_bytes, stream)`.
///
/// The ABI note's `target_i64` slot is the `target` parameter below.
/// `input` and `target` are `*const`; `out` is the only `*mut` data pointer.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `target` presumably points at `i64` elements, per the
/// `target_i64` label — confirm against the kernel source.
pub fn baracuda_kernels_loss_hinge_embedding_f32_run(
numel: i64, reduction_mode: i32, margin: f32,
input: *const c_void, target: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query for the HingeEmbedding FW f32 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_hinge_embedding_f32_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `out` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_hinge_embedding_f32_can_implement(
numel: i64, reduction_mode: i32, margin: f32,
input: *const c_void, target: *const c_void, out: *const c_void,
) -> i32;
/// HingeEmbedding FW, f16.
///
/// `input` and `target` are `*const`; `out` is the only `*mut` data pointer.
/// `margin` is passed as `f32`. `workspace`, `workspace_bytes`, and `stream`
/// follow the crate-level contract. Returns an `i32` status code
/// (`0` = success).
/// NOTE(review): `target` is presumably an `i64` tensor, as the f32 variant's
/// ABI note labels it `target_i64` — confirm.
pub fn baracuda_kernels_loss_hinge_embedding_f16_run(
numel: i64, reduction_mode: i32, margin: f32,
input: *const c_void, target: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query for the HingeEmbedding FW f16 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_hinge_embedding_f16_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `out` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_hinge_embedding_f16_can_implement(
numel: i64, reduction_mode: i32, margin: f32,
input: *const c_void, target: *const c_void, out: *const c_void,
) -> i32;
/// HingeEmbedding FW, bf16.
///
/// `input` and `target` are `*const`; `out` is the only `*mut` data pointer.
/// `margin` is passed as `f32`. `workspace`, `workspace_bytes`, and `stream`
/// follow the crate-level contract. Returns an `i32` status code
/// (`0` = success).
/// NOTE(review): `target` is presumably an `i64` tensor, as the f32 variant's
/// ABI note labels it `target_i64` — confirm.
pub fn baracuda_kernels_loss_hinge_embedding_bf16_run(
numel: i64, reduction_mode: i32, margin: f32,
input: *const c_void, target: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query for the HingeEmbedding FW bf16 kernel.
///
/// Takes the arguments of `baracuda_kernels_loss_hinge_embedding_bf16_run`
/// minus `workspace`, `workspace_bytes`, and `stream`, with `out` as
/// `*const`. Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" section).
pub fn baracuda_kernels_loss_hinge_embedding_bf16_can_implement(
numel: i64, reduction_mode: i32, margin: f32,
input: *const c_void, target: *const c_void, out: *const c_void,
) -> i32;
/// HingeEmbedding FW, f64.
///
/// `input` and `target` are `*const`; `out` is the only `*mut` data pointer.
/// `workspace`, `workspace_bytes`, and `stream` follow the crate-level
/// contract. Returns an `i32` status code (`0` = success).
/// NOTE(review): `margin` is declared `f32` even for the f64 variant, and
/// `target` is presumably an `i64` tensor (the f32 variant's ABI note labels
/// it `target_i64`) — confirm both against the C prototype.
pub fn baracuda_kernels_loss_hinge_embedding_f64_run(
numel: i64, reduction_mode: i32, margin: f32,
input: *const c_void, target: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_hinge_embedding_f64_run`: same problem arguments
/// without `workspace`, `workspace_bytes`, or `stream`, and `out` typed
/// `*const`. Returns an `i32` status code (`0` = success; see "Status codes"
/// in the crate docs).
pub fn baracuda_kernels_loss_hinge_embedding_f64_can_implement(
numel: i64, reduction_mode: i32, margin: f32,
input: *const c_void, target: *const c_void, out: *const c_void,
) -> i32;
/// HingeEmbedding backward (BW) `_run` entry point, f32 variant.
/// Argument order: `(numel, reduction_mode, scale_scalar, margin, input,
/// target, dy, dinput, workspace, workspace_bytes, stream)`; `dinput` is the
/// only non-const data pointer.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_hinge_embedding_backward_f32_run(
numel: i64, reduction_mode: i32, scale_scalar: f32, margin: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_hinge_embedding_backward_f32_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `dinput` typed `*const`. Returns an `i32` status code (`0` = success; see
/// "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_hinge_embedding_backward_f32_can_implement(
numel: i64, reduction_mode: i32, scale_scalar: f32, margin: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *const c_void,
) -> i32;
/// HingeEmbedding backward (BW) `_run` entry point, f16 variant.
/// Argument order: `(numel, reduction_mode, scale_scalar, margin, input,
/// target, dy, dinput, workspace, workspace_bytes, stream)`; `dinput` is the
/// only non-const data pointer.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_hinge_embedding_backward_f16_run(
numel: i64, reduction_mode: i32, scale_scalar: f32, margin: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_hinge_embedding_backward_f16_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `dinput` typed `*const`. Returns an `i32` status code (`0` = success; see
/// "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_hinge_embedding_backward_f16_can_implement(
numel: i64, reduction_mode: i32, scale_scalar: f32, margin: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *const c_void,
) -> i32;
/// HingeEmbedding backward (BW) `_run` entry point, bf16 variant.
/// Argument order: `(numel, reduction_mode, scale_scalar, margin, input,
/// target, dy, dinput, workspace, workspace_bytes, stream)`; `dinput` is the
/// only non-const data pointer.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_hinge_embedding_backward_bf16_run(
numel: i64, reduction_mode: i32, scale_scalar: f32, margin: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_hinge_embedding_backward_bf16_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `dinput` typed `*const`. Returns an `i32` status code (`0` = success; see
/// "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_hinge_embedding_backward_bf16_can_implement(
numel: i64, reduction_mode: i32, scale_scalar: f32, margin: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *const c_void,
) -> i32;
/// HingeEmbedding backward (BW) `_run` entry point, f64 variant.
/// Argument order: `(numel, reduction_mode, scale_scalar, margin, input,
/// target, dy, dinput, workspace, workspace_bytes, stream)`. The scalars
/// `scale_scalar` and `margin` are still passed as `f32` in this variant.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_hinge_embedding_backward_f64_run(
numel: i64, reduction_mode: i32, scale_scalar: f32, margin: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_hinge_embedding_backward_f64_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `dinput` typed `*const`. Returns an `i32` status code (`0` = success; see
/// "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_hinge_embedding_backward_f64_can_implement(
numel: i64, reduction_mode: i32, scale_scalar: f32, margin: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *const c_void,
) -> i32;
/// CosineEmbedding forward (FW, per-row) `_run` entry point, f32 variant.
/// ABI: `(n_rows, d_extent, row_stride_x, reduction_mode, margin, x1, x2, t,
/// out, workspace, workspace_bytes, stream)`. A single `row_stride_x` is
/// passed for both `x1` and `x2`.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_cosine_embedding_f32_run(
n_rows: i64, d_extent: i32, row_stride_x: i64,
reduction_mode: i32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_cosine_embedding_f32_run`: same problem arguments
/// without `workspace`, `workspace_bytes`, or `stream`, and `out` typed
/// `*const`. Returns an `i32` status code (`0` = success; see "Status codes"
/// in the crate docs).
pub fn baracuda_kernels_loss_cosine_embedding_f32_can_implement(
n_rows: i64, d_extent: i32, row_stride_x: i64,
reduction_mode: i32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, out: *const c_void,
) -> i32;
/// CosineEmbedding forward (FW) `_run` entry point, f16 variant.
/// Argument order: `(n_rows, d_extent, row_stride_x, reduction_mode, margin,
/// x1, x2, t, out, workspace, workspace_bytes, stream)`.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_cosine_embedding_f16_run(
n_rows: i64, d_extent: i32, row_stride_x: i64,
reduction_mode: i32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_cosine_embedding_f16_run`: same problem arguments
/// without `workspace`, `workspace_bytes`, or `stream`, and `out` typed
/// `*const`. Returns an `i32` status code (`0` = success; see "Status codes"
/// in the crate docs).
pub fn baracuda_kernels_loss_cosine_embedding_f16_can_implement(
n_rows: i64, d_extent: i32, row_stride_x: i64,
reduction_mode: i32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, out: *const c_void,
) -> i32;
/// CosineEmbedding forward (FW) `_run` entry point, bf16 variant.
/// Argument order: `(n_rows, d_extent, row_stride_x, reduction_mode, margin,
/// x1, x2, t, out, workspace, workspace_bytes, stream)`.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_cosine_embedding_bf16_run(
n_rows: i64, d_extent: i32, row_stride_x: i64,
reduction_mode: i32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_cosine_embedding_bf16_run`: same problem arguments
/// without `workspace`, `workspace_bytes`, or `stream`, and `out` typed
/// `*const`. Returns an `i32` status code (`0` = success; see "Status codes"
/// in the crate docs).
pub fn baracuda_kernels_loss_cosine_embedding_bf16_can_implement(
n_rows: i64, d_extent: i32, row_stride_x: i64,
reduction_mode: i32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, out: *const c_void,
) -> i32;
/// CosineEmbedding forward (FW) `_run` entry point, f64 variant.
/// Argument order: `(n_rows, d_extent, row_stride_x, reduction_mode, margin,
/// x1, x2, t, out, workspace, workspace_bytes, stream)`. The scalar `margin`
/// is still passed as `f32` in this variant.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_cosine_embedding_f64_run(
n_rows: i64, d_extent: i32, row_stride_x: i64,
reduction_mode: i32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_cosine_embedding_f64_run`: same problem arguments
/// without `workspace`, `workspace_bytes`, or `stream`, and `out` typed
/// `*const`. Returns an `i32` status code (`0` = success; see "Status codes"
/// in the crate docs).
pub fn baracuda_kernels_loss_cosine_embedding_f64_can_implement(
n_rows: i64, d_extent: i32, row_stride_x: i64,
reduction_mode: i32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, out: *const c_void,
) -> i32;
/// CosineEmbedding backward (BW) `_run` entry point, f32 variant.
/// Argument order: `(n_rows, d_extent, row_stride_x, reduction_mode,
/// scale_scalar, margin, x1, x2, t, dy, dx1, dx2, workspace, workspace_bytes,
/// stream)`; `dx1` and `dx2` are the non-const data pointers.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_cosine_embedding_backward_f32_run(
n_rows: i64, d_extent: i32, row_stride_x: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, dy: *const c_void,
dx1: *mut c_void, dx2: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_cosine_embedding_backward_f32_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `dx1`/`dx2` typed `*const`. Returns an `i32` status code (`0` = success;
/// see "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_cosine_embedding_backward_f32_can_implement(
n_rows: i64, d_extent: i32, row_stride_x: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, dy: *const c_void,
dx1: *const c_void, dx2: *const c_void,
) -> i32;
/// CosineEmbedding backward (BW) `_run` entry point, f16 variant.
/// Argument order: `(n_rows, d_extent, row_stride_x, reduction_mode,
/// scale_scalar, margin, x1, x2, t, dy, dx1, dx2, workspace, workspace_bytes,
/// stream)`; `dx1` and `dx2` are the non-const data pointers.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_cosine_embedding_backward_f16_run(
n_rows: i64, d_extent: i32, row_stride_x: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, dy: *const c_void,
dx1: *mut c_void, dx2: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_cosine_embedding_backward_f16_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `dx1`/`dx2` typed `*const`. Returns an `i32` status code (`0` = success;
/// see "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_cosine_embedding_backward_f16_can_implement(
n_rows: i64, d_extent: i32, row_stride_x: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, dy: *const c_void,
dx1: *const c_void, dx2: *const c_void,
) -> i32;
/// CosineEmbedding backward (BW) `_run` entry point, bf16 variant.
/// Argument order: `(n_rows, d_extent, row_stride_x, reduction_mode,
/// scale_scalar, margin, x1, x2, t, dy, dx1, dx2, workspace, workspace_bytes,
/// stream)`; `dx1` and `dx2` are the non-const data pointers.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_cosine_embedding_backward_bf16_run(
n_rows: i64, d_extent: i32, row_stride_x: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, dy: *const c_void,
dx1: *mut c_void, dx2: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_cosine_embedding_backward_bf16_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `dx1`/`dx2` typed `*const`. Returns an `i32` status code (`0` = success;
/// see "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_cosine_embedding_backward_bf16_can_implement(
n_rows: i64, d_extent: i32, row_stride_x: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, dy: *const c_void,
dx1: *const c_void, dx2: *const c_void,
) -> i32;
/// CosineEmbedding backward (BW) `_run` entry point, f64 variant.
/// Argument order: `(n_rows, d_extent, row_stride_x, reduction_mode,
/// scale_scalar, margin, x1, x2, t, dy, dx1, dx2, workspace, workspace_bytes,
/// stream)`. The scalars `scale_scalar` and `margin` are still passed as
/// `f32` in this variant.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_cosine_embedding_backward_f64_run(
n_rows: i64, d_extent: i32, row_stride_x: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, dy: *const c_void,
dx1: *mut c_void, dx2: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_cosine_embedding_backward_f64_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `dx1`/`dx2` typed `*const`. Returns an `i32` status code (`0` = success;
/// see "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_cosine_embedding_backward_f64_can_implement(
n_rows: i64, d_extent: i32, row_stride_x: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32,
x1: *const c_void, x2: *const c_void, t: *const c_void, dy: *const c_void,
dx1: *const c_void, dx2: *const c_void,
) -> i32;
/// TripletMargin forward (FW, per-row) `_run` entry point, f32 variant.
/// ABI: `(n_rows, d_extent, row_stride, reduction_mode, margin, p_norm, a,
/// p_tensor, n_tensor, out, workspace, workspace_bytes, stream)`. One
/// `row_stride` is passed for all three operands `a`, `p_tensor`, `n_tensor`.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_triplet_margin_f32_run(
n_rows: i64, d_extent: i32, row_stride: i64,
reduction_mode: i32, margin: f32, p_norm: f32,
a: *const c_void, p_tensor: *const c_void, n_tensor: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_triplet_margin_f32_run`: same problem arguments
/// without `workspace`, `workspace_bytes`, or `stream`, and `out` typed
/// `*const`. Returns an `i32` status code (`0` = success; see "Status codes"
/// in the crate docs).
pub fn baracuda_kernels_loss_triplet_margin_f32_can_implement(
n_rows: i64, d_extent: i32, row_stride: i64,
reduction_mode: i32, margin: f32, p_norm: f32,
a: *const c_void, p_tensor: *const c_void, n_tensor: *const c_void, out: *const c_void,
) -> i32;
/// TripletMargin forward (FW) `_run` entry point, f16 variant.
/// Argument order: `(n_rows, d_extent, row_stride, reduction_mode, margin,
/// p_norm, a, p_tensor, n_tensor, out, workspace, workspace_bytes, stream)`.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_triplet_margin_f16_run(
n_rows: i64, d_extent: i32, row_stride: i64,
reduction_mode: i32, margin: f32, p_norm: f32,
a: *const c_void, p_tensor: *const c_void, n_tensor: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_triplet_margin_f16_run`: same problem arguments
/// without `workspace`, `workspace_bytes`, or `stream`, and `out` typed
/// `*const`. Returns an `i32` status code (`0` = success; see "Status codes"
/// in the crate docs).
pub fn baracuda_kernels_loss_triplet_margin_f16_can_implement(
n_rows: i64, d_extent: i32, row_stride: i64,
reduction_mode: i32, margin: f32, p_norm: f32,
a: *const c_void, p_tensor: *const c_void, n_tensor: *const c_void, out: *const c_void,
) -> i32;
/// TripletMargin forward (FW) `_run` entry point, bf16 variant.
/// Argument order: `(n_rows, d_extent, row_stride, reduction_mode, margin,
/// p_norm, a, p_tensor, n_tensor, out, workspace, workspace_bytes, stream)`.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_triplet_margin_bf16_run(
n_rows: i64, d_extent: i32, row_stride: i64,
reduction_mode: i32, margin: f32, p_norm: f32,
a: *const c_void, p_tensor: *const c_void, n_tensor: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_triplet_margin_bf16_run`: same problem arguments
/// without `workspace`, `workspace_bytes`, or `stream`, and `out` typed
/// `*const`. Returns an `i32` status code (`0` = success; see "Status codes"
/// in the crate docs).
pub fn baracuda_kernels_loss_triplet_margin_bf16_can_implement(
n_rows: i64, d_extent: i32, row_stride: i64,
reduction_mode: i32, margin: f32, p_norm: f32,
a: *const c_void, p_tensor: *const c_void, n_tensor: *const c_void, out: *const c_void,
) -> i32;
/// TripletMargin forward (FW) `_run` entry point, f64 variant.
/// Argument order: `(n_rows, d_extent, row_stride, reduction_mode, margin,
/// p_norm, a, p_tensor, n_tensor, out, workspace, workspace_bytes, stream)`.
/// The scalars `margin` and `p_norm` are still passed as `f32` in this
/// variant.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_triplet_margin_f64_run(
n_rows: i64, d_extent: i32, row_stride: i64,
reduction_mode: i32, margin: f32, p_norm: f32,
a: *const c_void, p_tensor: *const c_void, n_tensor: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_triplet_margin_f64_run`: same problem arguments
/// without `workspace`, `workspace_bytes`, or `stream`, and `out` typed
/// `*const`. Returns an `i32` status code (`0` = success; see "Status codes"
/// in the crate docs).
pub fn baracuda_kernels_loss_triplet_margin_f64_can_implement(
n_rows: i64, d_extent: i32, row_stride: i64,
reduction_mode: i32, margin: f32, p_norm: f32,
a: *const c_void, p_tensor: *const c_void, n_tensor: *const c_void, out: *const c_void,
) -> i32;
/// TripletMargin backward (BW) `_run` entry point, f32 variant.
/// Argument order: `(n_rows, d_extent, row_stride, reduction_mode,
/// scale_scalar, margin, p_norm, a, p_tensor, n_tensor, dy, da, dp, dn,
/// workspace, workspace_bytes, stream)`; `da`, `dp`, and `dn` are the
/// non-const data pointers.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_triplet_margin_backward_f32_run(
n_rows: i64, d_extent: i32, row_stride: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32, p_norm: f32,
a: *const c_void, p_tensor: *const c_void, n_tensor: *const c_void, dy: *const c_void,
da: *mut c_void, dp: *mut c_void, dn: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_triplet_margin_backward_f32_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `da`/`dp`/`dn` typed `*const`. Returns an `i32` status code (`0` =
/// success; see "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_triplet_margin_backward_f32_can_implement(
n_rows: i64, d_extent: i32, row_stride: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32, p_norm: f32,
a: *const c_void, p_tensor: *const c_void, n_tensor: *const c_void, dy: *const c_void,
da: *const c_void, dp: *const c_void, dn: *const c_void,
) -> i32;
/// TripletMargin backward (BW) `_run` entry point, f16 variant.
/// Argument order: `(n_rows, d_extent, row_stride, reduction_mode,
/// scale_scalar, margin, p_norm, a, p_tensor, n_tensor, dy, da, dp, dn,
/// workspace, workspace_bytes, stream)`; `da`, `dp`, and `dn` are the
/// non-const data pointers.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_triplet_margin_backward_f16_run(
n_rows: i64, d_extent: i32, row_stride: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32, p_norm: f32,
a: *const c_void, p_tensor: *const c_void, n_tensor: *const c_void, dy: *const c_void,
da: *mut c_void, dp: *mut c_void, dn: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_triplet_margin_backward_f16_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `da`/`dp`/`dn` typed `*const`. Returns an `i32` status code (`0` =
/// success; see "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_triplet_margin_backward_f16_can_implement(
n_rows: i64, d_extent: i32, row_stride: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32, p_norm: f32,
a: *const c_void, p_tensor: *const c_void, n_tensor: *const c_void, dy: *const c_void,
da: *const c_void, dp: *const c_void, dn: *const c_void,
) -> i32;
/// TripletMargin backward (BW) `_run` entry point, bf16 variant.
/// Argument order: `(n_rows, d_extent, row_stride, reduction_mode,
/// scale_scalar, margin, p_norm, a, p_tensor, n_tensor, dy, da, dp, dn,
/// workspace, workspace_bytes, stream)`; `da`, `dp`, and `dn` are the
/// non-const data pointers.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_triplet_margin_backward_bf16_run(
n_rows: i64, d_extent: i32, row_stride: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32, p_norm: f32,
a: *const c_void, p_tensor: *const c_void, n_tensor: *const c_void, dy: *const c_void,
da: *mut c_void, dp: *mut c_void, dn: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_triplet_margin_backward_bf16_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `da`/`dp`/`dn` typed `*const`. Returns an `i32` status code (`0` =
/// success; see "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_triplet_margin_backward_bf16_can_implement(
n_rows: i64, d_extent: i32, row_stride: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32, p_norm: f32,
a: *const c_void, p_tensor: *const c_void, n_tensor: *const c_void, dy: *const c_void,
da: *const c_void, dp: *const c_void, dn: *const c_void,
) -> i32;
/// TripletMargin backward (BW) `_run` entry point, f64 variant.
/// Argument order: `(n_rows, d_extent, row_stride, reduction_mode,
/// scale_scalar, margin, p_norm, a, p_tensor, n_tensor, dy, da, dp, dn,
/// workspace, workspace_bytes, stream)`. The scalars `scale_scalar`,
/// `margin`, and `p_norm` are still passed as `f32` in this variant.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_triplet_margin_backward_f64_run(
n_rows: i64, d_extent: i32, row_stride: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32, p_norm: f32,
a: *const c_void, p_tensor: *const c_void, n_tensor: *const c_void, dy: *const c_void,
da: *mut c_void, dp: *mut c_void, dn: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_triplet_margin_backward_f64_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `da`/`dp`/`dn` typed `*const`. Returns an `i32` status code (`0` =
/// success; see "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_triplet_margin_backward_f64_can_implement(
n_rows: i64, d_extent: i32, row_stride: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32, p_norm: f32,
a: *const c_void, p_tensor: *const c_void, n_tensor: *const c_void, dy: *const c_void,
da: *const c_void, dp: *const c_void, dn: *const c_void,
) -> i32;
/// MultiMargin forward (FW, per-row) `_run` entry point, f32 variant.
/// ABI: `(n_rows, class_extent, row_stride, reduction_mode, margin, p_norm,
/// input, target_i64, out, workspace, workspace_bytes, stream)`, where
/// `target_i64` is the `target` parameter below.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multi_margin_f32_run(
n_rows: i64, class_extent: i32, row_stride: i64,
reduction_mode: i32, margin: f32, p_norm: f32,
input: *const c_void, target: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multi_margin_f32_run`: same problem arguments
/// without `workspace`, `workspace_bytes`, or `stream`, and `out` typed
/// `*const`. Returns an `i32` status code (`0` = success; see "Status codes"
/// in the crate docs).
pub fn baracuda_kernels_loss_multi_margin_f32_can_implement(
n_rows: i64, class_extent: i32, row_stride: i64,
reduction_mode: i32, margin: f32, p_norm: f32,
input: *const c_void, target: *const c_void, out: *const c_void,
) -> i32;
/// MultiMargin forward (FW) `_run` entry point, f16 variant.
/// Argument order: `(n_rows, class_extent, row_stride, reduction_mode,
/// margin, p_norm, input, target, out, workspace, workspace_bytes, stream)`.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multi_margin_f16_run(
n_rows: i64, class_extent: i32, row_stride: i64,
reduction_mode: i32, margin: f32, p_norm: f32,
input: *const c_void, target: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multi_margin_f16_run`: same problem arguments
/// without `workspace`, `workspace_bytes`, or `stream`, and `out` typed
/// `*const`. Returns an `i32` status code (`0` = success; see "Status codes"
/// in the crate docs).
pub fn baracuda_kernels_loss_multi_margin_f16_can_implement(
n_rows: i64, class_extent: i32, row_stride: i64,
reduction_mode: i32, margin: f32, p_norm: f32,
input: *const c_void, target: *const c_void, out: *const c_void,
) -> i32;
/// MultiMargin forward (FW) `_run` entry point, bf16 variant.
/// Argument order: `(n_rows, class_extent, row_stride, reduction_mode,
/// margin, p_norm, input, target, out, workspace, workspace_bytes, stream)`.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multi_margin_bf16_run(
n_rows: i64, class_extent: i32, row_stride: i64,
reduction_mode: i32, margin: f32, p_norm: f32,
input: *const c_void, target: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multi_margin_bf16_run`: same problem arguments
/// without `workspace`, `workspace_bytes`, or `stream`, and `out` typed
/// `*const`. Returns an `i32` status code (`0` = success; see "Status codes"
/// in the crate docs).
pub fn baracuda_kernels_loss_multi_margin_bf16_can_implement(
n_rows: i64, class_extent: i32, row_stride: i64,
reduction_mode: i32, margin: f32, p_norm: f32,
input: *const c_void, target: *const c_void, out: *const c_void,
) -> i32;
/// MultiMargin forward (FW) `_run` entry point, f64 variant.
/// Argument order: `(n_rows, class_extent, row_stride, reduction_mode,
/// margin, p_norm, input, target, out, workspace, workspace_bytes, stream)`.
/// The scalars `margin` and `p_norm` are still passed as `f32` in this
/// variant.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multi_margin_f64_run(
n_rows: i64, class_extent: i32, row_stride: i64,
reduction_mode: i32, margin: f32, p_norm: f32,
input: *const c_void, target: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multi_margin_f64_run`: same problem arguments
/// without `workspace`, `workspace_bytes`, or `stream`, and `out` typed
/// `*const`. Returns an `i32` status code (`0` = success; see "Status codes"
/// in the crate docs).
pub fn baracuda_kernels_loss_multi_margin_f64_can_implement(
n_rows: i64, class_extent: i32, row_stride: i64,
reduction_mode: i32, margin: f32, p_norm: f32,
input: *const c_void, target: *const c_void, out: *const c_void,
) -> i32;
/// MultiMargin backward (BW) `_run` entry point, f32 variant.
/// Argument order: `(n_rows, class_extent, row_stride, reduction_mode,
/// scale_scalar, margin, p_norm, input, target, dy, dinput, workspace,
/// workspace_bytes, stream)`; `dinput` is the only non-const data pointer.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multi_margin_backward_f32_run(
n_rows: i64, class_extent: i32, row_stride: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32, p_norm: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multi_margin_backward_f32_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `dinput` typed `*const`. Returns an `i32` status code (`0` = success; see
/// "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_multi_margin_backward_f32_can_implement(
n_rows: i64, class_extent: i32, row_stride: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32, p_norm: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *const c_void,
) -> i32;
/// MultiMargin backward (BW) `_run` entry point, f16 variant.
/// Argument order: `(n_rows, class_extent, row_stride, reduction_mode,
/// scale_scalar, margin, p_norm, input, target, dy, dinput, workspace,
/// workspace_bytes, stream)`; `dinput` is the only non-const data pointer.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multi_margin_backward_f16_run(
n_rows: i64, class_extent: i32, row_stride: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32, p_norm: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multi_margin_backward_f16_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `dinput` typed `*const`. Returns an `i32` status code (`0` = success; see
/// "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_multi_margin_backward_f16_can_implement(
n_rows: i64, class_extent: i32, row_stride: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32, p_norm: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *const c_void,
) -> i32;
/// MultiMargin backward (BW) `_run` entry point, bf16 variant.
/// Argument order: `(n_rows, class_extent, row_stride, reduction_mode,
/// scale_scalar, margin, p_norm, input, target, dy, dinput, workspace,
/// workspace_bytes, stream)`; `dinput` is the only non-const data pointer.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multi_margin_backward_bf16_run(
n_rows: i64, class_extent: i32, row_stride: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32, p_norm: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multi_margin_backward_bf16_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `dinput` typed `*const`. Returns an `i32` status code (`0` = success; see
/// "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_multi_margin_backward_bf16_can_implement(
n_rows: i64, class_extent: i32, row_stride: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32, p_norm: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *const c_void,
) -> i32;
/// MultiMargin backward (BW) `_run` entry point, f64 variant.
/// Argument order: `(n_rows, class_extent, row_stride, reduction_mode,
/// scale_scalar, margin, p_norm, input, target, dy, dinput, workspace,
/// workspace_bytes, stream)`. The scalars `scale_scalar`, `margin`, and
/// `p_norm` are still passed as `f32` in this variant.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multi_margin_backward_f64_run(
n_rows: i64, class_extent: i32, row_stride: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32, p_norm: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multi_margin_backward_f64_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `dinput` typed `*const`. Returns an `i32` status code (`0` = success; see
/// "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_multi_margin_backward_f64_can_implement(
n_rows: i64, class_extent: i32, row_stride: i64,
reduction_mode: i32, scale_scalar: f32, margin: f32, p_norm: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *const c_void,
) -> i32;
/// MultilabelMargin forward (FW, per-row) `_run` entry point, f32 variant.
/// ABI: `(n_rows, class_extent, row_stride_in, row_stride_tgt,
/// reduction_mode, input, target_i64, out, workspace, workspace_bytes,
/// stream)`, where `target_i64` is the `target` parameter below. `input` and
/// `target` each get their own row stride.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multilabel_margin_f32_run(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32,
input: *const c_void, target: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multilabel_margin_f32_run`: same problem arguments
/// without `workspace`, `workspace_bytes`, or `stream`, and `out` typed
/// `*const`. Returns an `i32` status code (`0` = success; see "Status codes"
/// in the crate docs).
pub fn baracuda_kernels_loss_multilabel_margin_f32_can_implement(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32,
input: *const c_void, target: *const c_void, out: *const c_void,
) -> i32;
/// MultilabelMargin forward (FW) `_run` entry point, f16 variant.
/// Argument order: `(n_rows, class_extent, row_stride_in, row_stride_tgt,
/// reduction_mode, input, target, out, workspace, workspace_bytes, stream)`.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multilabel_margin_f16_run(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32,
input: *const c_void, target: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multilabel_margin_f16_run`: same problem arguments
/// without `workspace`, `workspace_bytes`, or `stream`, and `out` typed
/// `*const`. Returns an `i32` status code (`0` = success; see "Status codes"
/// in the crate docs).
pub fn baracuda_kernels_loss_multilabel_margin_f16_can_implement(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32,
input: *const c_void, target: *const c_void, out: *const c_void,
) -> i32;
/// MultilabelMargin forward (FW) `_run` entry point, bf16 variant.
/// Argument order: `(n_rows, class_extent, row_stride_in, row_stride_tgt,
/// reduction_mode, input, target, out, workspace, workspace_bytes, stream)`.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multilabel_margin_bf16_run(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32,
input: *const c_void, target: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multilabel_margin_bf16_run`: same problem arguments
/// without `workspace`, `workspace_bytes`, or `stream`, and `out` typed
/// `*const`. Returns an `i32` status code (`0` = success; see "Status codes"
/// in the crate docs).
pub fn baracuda_kernels_loss_multilabel_margin_bf16_can_implement(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32,
input: *const c_void, target: *const c_void, out: *const c_void,
) -> i32;
/// MultilabelMargin forward (FW) `_run` entry point, f64 variant.
/// Argument order: `(n_rows, class_extent, row_stride_in, row_stride_tgt,
/// reduction_mode, input, target, out, workspace, workspace_bytes, stream)`.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multilabel_margin_f64_run(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32,
input: *const c_void, target: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multilabel_margin_f64_run`: same problem arguments
/// without `workspace`, `workspace_bytes`, or `stream`, and `out` typed
/// `*const`. Returns an `i32` status code (`0` = success; see "Status codes"
/// in the crate docs).
pub fn baracuda_kernels_loss_multilabel_margin_f64_can_implement(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32,
input: *const c_void, target: *const c_void, out: *const c_void,
) -> i32;
/// MultilabelMargin backward (BW) `_run` entry point, f32 variant.
/// Argument order: `(n_rows, class_extent, row_stride_in, row_stride_tgt,
/// reduction_mode, scale_scalar, input, target, dy, dinput, workspace,
/// workspace_bytes, stream)`; `dinput` is the only non-const data pointer.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multilabel_margin_backward_f32_run(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32, scale_scalar: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multilabel_margin_backward_f32_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `dinput` typed `*const`. Returns an `i32` status code (`0` = success; see
/// "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_multilabel_margin_backward_f32_can_implement(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32, scale_scalar: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *const c_void,
) -> i32;
/// MultilabelMargin backward (BW) `_run` entry point, f16 variant.
/// Argument order: `(n_rows, class_extent, row_stride_in, row_stride_tgt,
/// reduction_mode, scale_scalar, input, target, dy, dinput, workspace,
/// workspace_bytes, stream)`; `dinput` is the only non-const data pointer.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multilabel_margin_backward_f16_run(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32, scale_scalar: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multilabel_margin_backward_f16_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `dinput` typed `*const`. Returns an `i32` status code (`0` = success; see
/// "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_multilabel_margin_backward_f16_can_implement(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32, scale_scalar: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *const c_void,
) -> i32;
/// MultilabelMargin backward (BW) `_run` entry point, bf16 variant.
/// Argument order: `(n_rows, class_extent, row_stride_in, row_stride_tgt,
/// reduction_mode, scale_scalar, input, target, dy, dinput, workspace,
/// workspace_bytes, stream)`; `dinput` is the only non-const data pointer.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multilabel_margin_backward_bf16_run(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32, scale_scalar: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multilabel_margin_backward_bf16_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `dinput` typed `*const`. Returns an `i32` status code (`0` = success; see
/// "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_multilabel_margin_backward_bf16_can_implement(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32, scale_scalar: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *const c_void,
) -> i32;
/// MultilabelMargin backward (BW) `_run` entry point, f64 variant.
/// Argument order: `(n_rows, class_extent, row_stride_in, row_stride_tgt,
/// reduction_mode, scale_scalar, input, target, dy, dinput, workspace,
/// workspace_bytes, stream)`. The scalar `scale_scalar` is still passed as
/// `f32` in this variant.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multilabel_margin_backward_f64_run(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32, scale_scalar: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multilabel_margin_backward_f64_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and
/// `dinput` typed `*const`. Returns an `i32` status code (`0` = success; see
/// "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_multilabel_margin_backward_f64_can_implement(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32, scale_scalar: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *const c_void,
) -> i32;
/// MultilabelSoftMargin forward (FW) `_run` entry point, f32 variant.
/// Argument order: `(n_rows, class_extent, row_stride_in, row_stride_tgt,
/// reduction_mode, input, target, out, workspace, workspace_bytes, stream)`;
/// `input` and `target` each get their own row stride.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multilabel_soft_margin_f32_run(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32,
input: *const c_void, target: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multilabel_soft_margin_f32_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and `out`
/// typed `*const`. Returns an `i32` status code (`0` = success; see
/// "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_multilabel_soft_margin_f32_can_implement(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32,
input: *const c_void, target: *const c_void, out: *const c_void,
) -> i32;
/// MultilabelSoftMargin forward (FW) `_run` entry point, f16 variant.
/// Argument order: `(n_rows, class_extent, row_stride_in, row_stride_tgt,
/// reduction_mode, input, target, out, workspace, workspace_bytes, stream)`.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multilabel_soft_margin_f16_run(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32,
input: *const c_void, target: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multilabel_soft_margin_f16_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and `out`
/// typed `*const`. Returns an `i32` status code (`0` = success; see
/// "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_multilabel_soft_margin_f16_can_implement(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32,
input: *const c_void, target: *const c_void, out: *const c_void,
) -> i32;
/// MultilabelSoftMargin forward (FW) `_run` entry point, bf16 variant.
/// Argument order: `(n_rows, class_extent, row_stride_in, row_stride_tgt,
/// reduction_mode, input, target, out, workspace, workspace_bytes, stream)`.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multilabel_soft_margin_bf16_run(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32,
input: *const c_void, target: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multilabel_soft_margin_bf16_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and `out`
/// typed `*const`. Returns an `i32` status code (`0` = success; see
/// "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_multilabel_soft_margin_bf16_can_implement(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32,
input: *const c_void, target: *const c_void, out: *const c_void,
) -> i32;
/// MultilabelSoftMargin forward (FW) `_run` entry point, f64 variant.
/// Argument order: `(n_rows, class_extent, row_stride_in, row_stride_tgt,
/// reduction_mode, input, target, out, workspace, workspace_bytes, stream)`.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multilabel_soft_margin_f64_run(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32,
input: *const c_void, target: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multilabel_soft_margin_f64_run`: same problem
/// arguments without `workspace`, `workspace_bytes`, or `stream`, and `out`
/// typed `*const`. Returns an `i32` status code (`0` = success; see
/// "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_multilabel_soft_margin_f64_can_implement(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32,
input: *const c_void, target: *const c_void, out: *const c_void,
) -> i32;
/// MultilabelSoftMargin backward (BW) `_run` entry point, f32 variant.
/// Argument order: `(n_rows, class_extent, row_stride_in, row_stride_tgt,
/// reduction_mode, scale_scalar, input, target, dy, dinput, workspace,
/// workspace_bytes, stream)`; `dinput` is the only non-const data pointer.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multilabel_soft_margin_backward_f32_run(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32, scale_scalar: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` query paired with
/// `baracuda_kernels_loss_multilabel_soft_margin_backward_f32_run`: same
/// problem arguments without `workspace`, `workspace_bytes`, or `stream`,
/// and `dinput` typed `*const`. Returns an `i32` status code (`0` = success;
/// see "Status codes" in the crate docs).
pub fn baracuda_kernels_loss_multilabel_soft_margin_backward_f32_can_implement(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32, scale_scalar: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *const c_void,
) -> i32;
/// MultilabelSoftMargin backward (BW) `_run` entry point, f16 variant.
/// Argument order: `(n_rows, class_extent, row_stride_in, row_stride_tgt,
/// reduction_mode, scale_scalar, input, target, dy, dinput, workspace,
/// workspace_bytes, stream)`; `dinput` is the only non-const data pointer.
/// Returns an `i32` status code (`0` = success; see "Status codes" in the
/// crate docs); the crate-level pointer, workspace, and stream requirements
/// apply.
pub fn baracuda_kernels_loss_multilabel_soft_margin_backward_f16_run(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32, scale_scalar: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_loss_multilabel_soft_margin_backward_f16_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `dinput` is taken as `*const`. Returns a status code per
/// the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_loss_multilabel_soft_margin_backward_f16_can_implement(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32, scale_scalar: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *const c_void,
) -> i32;
/// MultilabelSoftMargin loss backward pass, bf16 variant.
///
/// Same argument list as
/// `baracuda_kernels_loss_multilabel_soft_margin_backward_f32_run`:
/// `input`, `target` and `dy` are `*const` operands and `dinput` is the
/// `*mut` destination. `scale_scalar` is passed as `f32` regardless of the
/// element type.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_loss_multilabel_soft_margin_backward_bf16_run(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32, scale_scalar: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_loss_multilabel_soft_margin_backward_bf16_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `dinput` is taken as `*const`. Returns a status code per
/// the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_loss_multilabel_soft_margin_backward_bf16_can_implement(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32, scale_scalar: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *const c_void,
) -> i32;
/// MultilabelSoftMargin loss backward pass, f64 variant.
///
/// Same argument list as
/// `baracuda_kernels_loss_multilabel_soft_margin_backward_f32_run`:
/// `input`, `target` and `dy` are `*const` operands and `dinput` is the
/// `*mut` destination.
///
/// NOTE(review): `scale_scalar` is declared `f32` even for this f64
/// variant — confirm this matches the C signature.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_loss_multilabel_soft_margin_backward_f64_run(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32, scale_scalar: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_loss_multilabel_soft_margin_backward_f64_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `dinput` is taken as `*const`. Returns a status code per
/// the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_loss_multilabel_soft_margin_backward_f64_can_implement(
n_rows: i64, class_extent: i32,
row_stride_in: i64, row_stride_tgt: i64,
reduction_mode: i32, scale_scalar: f32,
input: *const c_void, target: *const c_void, dy: *const c_void, dinput: *const c_void,
) -> i32;
// -------------------------------------------------------------------------
// CTCLoss — Milestone 5.5 (Phase 5 final deferral).
//
// log_probs is `Elem[T, N, C]` row-major, where `Elem` is the kernel's
// element type and the bracketed `T` is the time axis. targets is `i64[N, S]`.
// input_lengths / target_lengths are `i64[N]`. The kernel runs
// forward DP on the lattice once per batch sample (one CUDA block
// per sample). `alpha_ws` is a workspace of accumulator type
// (f32 for {f32, f16, bf16}; f64 for f64) shaped `[T, N, 2·S_max+1]`.
// `workspace` carries the per-sample loss buffer ([N] floats/doubles).
// -------------------------------------------------------------------------
/// CTCLoss forward pass, f32 variant.
///
/// `log_probs`, `targets`, `input_lengths` and `target_lengths` are
/// `*const` operands; `alpha_ws` and `out` are `*mut` destinations. Buffer
/// layouts and the roles of `alpha_ws` / `workspace` are described in the
/// CTCLoss section comment preceding these declarations. `zero_infinity`
/// is an `i32`, presumably a boolean flag — confirm against the kernel
/// source.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_loss_ctc_f32_run(
max_time: i32, batch_size: i32, num_classes: i32, max_target_len: i32,
blank: i32, reduction_mode: i32, zero_infinity: i32,
log_probs: *const c_void, targets: *const c_void,
input_lengths: *const c_void, target_lengths: *const c_void,
alpha_ws: *mut c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// CTCLoss FW `_can_implement`, f32. Validates `num_classes <= 32`,
/// `max_target_len <= 256`, `blank ∈ [0, num_classes)`, `reduction_mode ∈ [0,2]`.
///
/// Mirrors the `baracuda_kernels_loss_ctc_f32_run` argument list without
/// `workspace`, `workspace_bytes` and `stream`; `alpha_ws` and `out` are
/// taken as `*const`. Returns a status code per the crate-level
/// "Status codes" list (`0` = success).
pub fn baracuda_kernels_loss_ctc_f32_can_implement(
max_time: i32, batch_size: i32, num_classes: i32, max_target_len: i32,
blank: i32, reduction_mode: i32, zero_infinity: i32,
log_probs: *const c_void, targets: *const c_void,
input_lengths: *const c_void, target_lengths: *const c_void,
alpha_ws: *const c_void, out: *const c_void,
) -> i32;
/// CTCLoss forward pass, f16 variant.
///
/// Same argument list as `baracuda_kernels_loss_ctc_f32_run`. Per the
/// CTCLoss section comment preceding these declarations, `alpha_ws` uses
/// the f32 accumulator type for f16 inputs.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_loss_ctc_f16_run(
max_time: i32, batch_size: i32, num_classes: i32, max_target_len: i32,
blank: i32, reduction_mode: i32, zero_infinity: i32,
log_probs: *const c_void, targets: *const c_void,
input_lengths: *const c_void, target_lengths: *const c_void,
alpha_ws: *mut c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// CTCLoss FW `_can_implement`, f16 (F32_ACC).
///
/// Mirrors the `baracuda_kernels_loss_ctc_f16_run` argument list without
/// `workspace`, `workspace_bytes` and `stream`; `alpha_ws` and `out` are
/// taken as `*const`. Returns a status code per the crate-level
/// "Status codes" list (`0` = success).
pub fn baracuda_kernels_loss_ctc_f16_can_implement(
max_time: i32, batch_size: i32, num_classes: i32, max_target_len: i32,
blank: i32, reduction_mode: i32, zero_infinity: i32,
log_probs: *const c_void, targets: *const c_void,
input_lengths: *const c_void, target_lengths: *const c_void,
alpha_ws: *const c_void, out: *const c_void,
) -> i32;
/// CTCLoss forward pass, bf16 variant.
///
/// Same argument list as `baracuda_kernels_loss_ctc_f32_run`. Per the
/// CTCLoss section comment preceding these declarations, `alpha_ws` uses
/// the f32 accumulator type for bf16 inputs.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_loss_ctc_bf16_run(
max_time: i32, batch_size: i32, num_classes: i32, max_target_len: i32,
blank: i32, reduction_mode: i32, zero_infinity: i32,
log_probs: *const c_void, targets: *const c_void,
input_lengths: *const c_void, target_lengths: *const c_void,
alpha_ws: *mut c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// CTCLoss FW `_can_implement`, bf16 (F32_ACC).
///
/// Mirrors the `baracuda_kernels_loss_ctc_bf16_run` argument list without
/// `workspace`, `workspace_bytes` and `stream`; `alpha_ws` and `out` are
/// taken as `*const`. Returns a status code per the crate-level
/// "Status codes" list (`0` = success).
pub fn baracuda_kernels_loss_ctc_bf16_can_implement(
max_time: i32, batch_size: i32, num_classes: i32, max_target_len: i32,
blank: i32, reduction_mode: i32, zero_infinity: i32,
log_probs: *const c_void, targets: *const c_void,
input_lengths: *const c_void, target_lengths: *const c_void,
alpha_ws: *const c_void, out: *const c_void,
) -> i32;
/// CTCLoss forward pass, f64 variant.
///
/// Same argument list as `baracuda_kernels_loss_ctc_f32_run`. Per the
/// CTCLoss section comment preceding these declarations, `alpha_ws` uses
/// the f64 accumulator type for f64 inputs.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_loss_ctc_f64_run(
max_time: i32, batch_size: i32, num_classes: i32, max_target_len: i32,
blank: i32, reduction_mode: i32, zero_infinity: i32,
log_probs: *const c_void, targets: *const c_void,
input_lengths: *const c_void, target_lengths: *const c_void,
alpha_ws: *mut c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// CTCLoss FW `_can_implement`, f64 (F64_ACC).
///
/// Mirrors the `baracuda_kernels_loss_ctc_f64_run` argument list without
/// `workspace`, `workspace_bytes` and `stream`; `alpha_ws` and `out` are
/// taken as `*const`. Returns a status code per the crate-level
/// "Status codes" list (`0` = success).
pub fn baracuda_kernels_loss_ctc_f64_can_implement(
max_time: i32, batch_size: i32, num_classes: i32, max_target_len: i32,
blank: i32, reduction_mode: i32, zero_infinity: i32,
log_probs: *const c_void, targets: *const c_void,
input_lengths: *const c_void, target_lengths: *const c_void,
alpha_ws: *const c_void, out: *const c_void,
) -> i32;
/// CTCLoss backward pass, f32 variant.
///
/// `alpha_ws` and `per_sample_loss` are taken as `*const` here (presumably
/// the buffers produced by the matching forward call — confirm against the
/// kernel source). `dloss` is a `*const` operand and `dlog_probs` is the
/// `*mut` destination. The meaning of `inv_denom` is not visible in this
/// file.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_loss_ctc_backward_f32_run(
max_time: i32, batch_size: i32, num_classes: i32, max_target_len: i32,
blank: i32, reduction_mode: i32, zero_infinity: i32, inv_denom: f32,
log_probs: *const c_void, targets: *const c_void,
input_lengths: *const c_void, target_lengths: *const c_void,
alpha_ws: *const c_void, per_sample_loss: *const c_void,
dloss: *const c_void, dlog_probs: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// CTCLoss BW `_can_implement`, f32.
///
/// Mirrors the `baracuda_kernels_loss_ctc_backward_f32_run` argument list
/// without `workspace`, `workspace_bytes` and `stream`; `dlog_probs` is
/// taken as `*const`. Returns a status code per the crate-level
/// "Status codes" list (`0` = success).
pub fn baracuda_kernels_loss_ctc_backward_f32_can_implement(
max_time: i32, batch_size: i32, num_classes: i32, max_target_len: i32,
blank: i32, reduction_mode: i32, zero_infinity: i32, inv_denom: f32,
log_probs: *const c_void, targets: *const c_void,
input_lengths: *const c_void, target_lengths: *const c_void,
alpha_ws: *const c_void, per_sample_loss: *const c_void,
dloss: *const c_void, dlog_probs: *const c_void,
) -> i32;
/// CTCLoss backward pass, f16 variant.
///
/// Same argument list as `baracuda_kernels_loss_ctc_backward_f32_run`:
/// `alpha_ws`, `per_sample_loss` and `dloss` are `*const` operands and
/// `dlog_probs` is the `*mut` destination. `inv_denom` is passed as `f32`
/// regardless of the element type.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_loss_ctc_backward_f16_run(
max_time: i32, batch_size: i32, num_classes: i32, max_target_len: i32,
blank: i32, reduction_mode: i32, zero_infinity: i32, inv_denom: f32,
log_probs: *const c_void, targets: *const c_void,
input_lengths: *const c_void, target_lengths: *const c_void,
alpha_ws: *const c_void, per_sample_loss: *const c_void,
dloss: *const c_void, dlog_probs: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// CTCLoss BW `_can_implement`, f16.
///
/// Mirrors the `baracuda_kernels_loss_ctc_backward_f16_run` argument list
/// without `workspace`, `workspace_bytes` and `stream`; `dlog_probs` is
/// taken as `*const`. Returns a status code per the crate-level
/// "Status codes" list (`0` = success).
pub fn baracuda_kernels_loss_ctc_backward_f16_can_implement(
max_time: i32, batch_size: i32, num_classes: i32, max_target_len: i32,
blank: i32, reduction_mode: i32, zero_infinity: i32, inv_denom: f32,
log_probs: *const c_void, targets: *const c_void,
input_lengths: *const c_void, target_lengths: *const c_void,
alpha_ws: *const c_void, per_sample_loss: *const c_void,
dloss: *const c_void, dlog_probs: *const c_void,
) -> i32;
/// CTCLoss backward pass, bf16 variant.
///
/// Same argument list as `baracuda_kernels_loss_ctc_backward_f32_run`:
/// `alpha_ws`, `per_sample_loss` and `dloss` are `*const` operands and
/// `dlog_probs` is the `*mut` destination. `inv_denom` is passed as `f32`
/// regardless of the element type.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_loss_ctc_backward_bf16_run(
max_time: i32, batch_size: i32, num_classes: i32, max_target_len: i32,
blank: i32, reduction_mode: i32, zero_infinity: i32, inv_denom: f32,
log_probs: *const c_void, targets: *const c_void,
input_lengths: *const c_void, target_lengths: *const c_void,
alpha_ws: *const c_void, per_sample_loss: *const c_void,
dloss: *const c_void, dlog_probs: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// CTCLoss BW `_can_implement`, bf16.
///
/// Mirrors the `baracuda_kernels_loss_ctc_backward_bf16_run` argument list
/// without `workspace`, `workspace_bytes` and `stream`; `dlog_probs` is
/// taken as `*const`. Returns a status code per the crate-level
/// "Status codes" list (`0` = success).
pub fn baracuda_kernels_loss_ctc_backward_bf16_can_implement(
max_time: i32, batch_size: i32, num_classes: i32, max_target_len: i32,
blank: i32, reduction_mode: i32, zero_infinity: i32, inv_denom: f32,
log_probs: *const c_void, targets: *const c_void,
input_lengths: *const c_void, target_lengths: *const c_void,
alpha_ws: *const c_void, per_sample_loss: *const c_void,
dloss: *const c_void, dlog_probs: *const c_void,
) -> i32;
/// CTCLoss backward pass, f64 variant.
///
/// Same argument list as `baracuda_kernels_loss_ctc_backward_f32_run`:
/// `alpha_ws`, `per_sample_loss` and `dloss` are `*const` operands and
/// `dlog_probs` is the `*mut` destination.
///
/// NOTE(review): `inv_denom` is declared `f32` even for this f64 variant —
/// confirm this matches the C signature.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_loss_ctc_backward_f64_run(
max_time: i32, batch_size: i32, num_classes: i32, max_target_len: i32,
blank: i32, reduction_mode: i32, zero_infinity: i32, inv_denom: f32,
log_probs: *const c_void, targets: *const c_void,
input_lengths: *const c_void, target_lengths: *const c_void,
alpha_ws: *const c_void, per_sample_loss: *const c_void,
dloss: *const c_void, dlog_probs: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// CTCLoss BW `_can_implement`, f64 (F64_ACC).
///
/// Mirrors the `baracuda_kernels_loss_ctc_backward_f64_run` argument list
/// without `workspace`, `workspace_bytes` and `stream`; `dlog_probs` is
/// taken as `*const`. Returns a status code per the crate-level
/// "Status codes" list (`0` = success).
pub fn baracuda_kernels_loss_ctc_backward_f64_can_implement(
max_time: i32, batch_size: i32, num_classes: i32, max_target_len: i32,
blank: i32, reduction_mode: i32, zero_infinity: i32, inv_denom: f32,
log_probs: *const c_void, targets: *const c_void,
input_lengths: *const c_void, target_lengths: *const c_void,
alpha_ws: *const c_void, per_sample_loss: *const c_void,
dloss: *const c_void, dlog_probs: *const c_void,
) -> i32;
/// PReLU forward pass, f32 variant. ABI: `(numel, channel_stride,
/// channel_extent, scalar_weight, x, weight, y, workspace,
/// workspace_bytes, stream)`.
///
/// `x` and `weight` are `*const` operands; `y` is the `*mut` destination.
/// `scalar_weight` is an `i32`, presumably a flag selecting a single
/// shared weight — TODO confirm against the kernel source.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_prelu_f32_run(
numel: i64, channel_stride: i64,
channel_extent: i32, scalar_weight: i32,
x: *const c_void, weight: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_prelu_f32_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `y` is taken as `*const`. Returns a status code per the
/// crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_prelu_f32_can_implement(
numel: i64,
channel_stride: i64,
channel_extent: i32,
scalar_weight: i32,
x: *const c_void,
weight: *const c_void,
y: *const c_void,
) -> i32;
/// PReLU forward pass, f16 variant.
///
/// Same argument list as `baracuda_kernels_prelu_f32_run`: `x` and
/// `weight` are `*const` operands and `y` is the `*mut` destination.
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_prelu_f16_run(
numel: i64, channel_stride: i64,
channel_extent: i32, scalar_weight: i32,
x: *const c_void, weight: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_prelu_f16_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `y` is taken as `*const`. Returns a status code per the
/// crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_prelu_f16_can_implement(
numel: i64,
channel_stride: i64,
channel_extent: i32,
scalar_weight: i32,
x: *const c_void,
weight: *const c_void,
y: *const c_void,
) -> i32;
/// PReLU forward pass, bf16 variant.
///
/// Same argument list as `baracuda_kernels_prelu_f32_run`: `x` and
/// `weight` are `*const` operands and `y` is the `*mut` destination.
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_prelu_bf16_run(
numel: i64, channel_stride: i64,
channel_extent: i32, scalar_weight: i32,
x: *const c_void, weight: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_prelu_bf16_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `y` is taken as `*const`. Returns a status code per the
/// crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_prelu_bf16_can_implement(
numel: i64,
channel_stride: i64,
channel_extent: i32,
scalar_weight: i32,
x: *const c_void,
weight: *const c_void,
y: *const c_void,
) -> i32;
/// PReLU forward pass, f64 variant.
///
/// Same argument list as `baracuda_kernels_prelu_f32_run`: `x` and
/// `weight` are `*const` operands and `y` is the `*mut` destination.
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_prelu_f64_run(
numel: i64, channel_stride: i64,
channel_extent: i32, scalar_weight: i32,
x: *const c_void, weight: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_prelu_f64_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `y` is taken as `*const`. Returns a status code per the
/// crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_prelu_f64_can_implement(
numel: i64,
channel_stride: i64,
channel_extent: i32,
scalar_weight: i32,
x: *const c_void,
weight: *const c_void,
y: *const c_void,
) -> i32;
/// PReLU backward pass, f32 variant. ABI: `(numel, channel_stride,
/// channel_extent, scalar_weight, dy, x, weight, dx, dweight, workspace,
/// workspace_bytes, stream)`.
///
/// `dy`, `x` and `weight` are `*const` operands; `dx` and `dweight` are the
/// `*mut` destinations. `scalar_weight` is an `i32`, presumably a flag
/// selecting a single shared weight — TODO confirm against the kernel
/// source.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_prelu_backward_f32_run(
numel: i64, channel_stride: i64,
channel_extent: i32, scalar_weight: i32,
dy: *const c_void, x: *const c_void, weight: *const c_void,
dx: *mut c_void, dweight: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_prelu_backward_f32_run`.
/// Host-side only.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `dx` and `dweight` are taken as `*const`. Returns a status
/// code per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_prelu_backward_f32_can_implement(
numel: i64,
channel_stride: i64,
channel_extent: i32,
scalar_weight: i32,
dy: *const c_void,
x: *const c_void,
weight: *const c_void,
dx: *const c_void,
dweight: *const c_void,
) -> i32;
/// PReLU backward pass, f16 variant.
///
/// Same argument list as `baracuda_kernels_prelu_backward_f32_run`: `dy`,
/// `x` and `weight` are `*const` operands; `dx` and `dweight` are the
/// `*mut` destinations. Returns a status code per the crate-level
/// "Status codes" list (`0` = success).
pub fn baracuda_kernels_prelu_backward_f16_run(
numel: i64, channel_stride: i64,
channel_extent: i32, scalar_weight: i32,
dy: *const c_void, x: *const c_void, weight: *const c_void,
dx: *mut c_void, dweight: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_prelu_backward_f16_run`.
/// Host-side only.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `dx` and `dweight` are taken as `*const`. Returns a status
/// code per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_prelu_backward_f16_can_implement(
numel: i64,
channel_stride: i64,
channel_extent: i32,
scalar_weight: i32,
dy: *const c_void,
x: *const c_void,
weight: *const c_void,
dx: *const c_void,
dweight: *const c_void,
) -> i32;
/// PReLU backward pass, bf16 variant.
///
/// Same argument list as `baracuda_kernels_prelu_backward_f32_run`: `dy`,
/// `x` and `weight` are `*const` operands; `dx` and `dweight` are the
/// `*mut` destinations. Returns a status code per the crate-level
/// "Status codes" list (`0` = success).
pub fn baracuda_kernels_prelu_backward_bf16_run(
numel: i64, channel_stride: i64,
channel_extent: i32, scalar_weight: i32,
dy: *const c_void, x: *const c_void, weight: *const c_void,
dx: *mut c_void, dweight: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_prelu_backward_bf16_run`.
/// Host-side only.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `dx` and `dweight` are taken as `*const`. Returns a status
/// code per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_prelu_backward_bf16_can_implement(
numel: i64,
channel_stride: i64,
channel_extent: i32,
scalar_weight: i32,
dy: *const c_void,
x: *const c_void,
weight: *const c_void,
dx: *const c_void,
dweight: *const c_void,
) -> i32;
/// PReLU backward pass, f64 variant.
///
/// Same argument list as `baracuda_kernels_prelu_backward_f32_run`: `dy`,
/// `x` and `weight` are `*const` operands; `dx` and `dweight` are the
/// `*mut` destinations. Returns a status code per the crate-level
/// "Status codes" list (`0` = success).
pub fn baracuda_kernels_prelu_backward_f64_run(
numel: i64, channel_stride: i64,
channel_extent: i32, scalar_weight: i32,
dy: *const c_void, x: *const c_void, weight: *const c_void,
dx: *mut c_void, dweight: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_prelu_backward_f64_run`.
/// Host-side only.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `dx` and `dweight` are taken as `*const`. Returns a status
/// code per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_prelu_backward_f64_can_implement(
numel: i64,
channel_stride: i64,
channel_extent: i32,
scalar_weight: i32,
dy: *const c_void,
x: *const c_void,
weight: *const c_void,
dx: *const c_void,
dweight: *const c_void,
) -> i32;
/// Soft-target CrossEntropy backward pass, f16 variant.
///
/// `input`, `target` and `dy` are `*const` operands; `dinput` is the `*mut`
/// destination. `row_stride_input` / `row_stride_target` are the row
/// strides of `input` / `target` (units presumably elements — confirm
/// against the kernel source). The meaning of `scale_scalar` and the
/// accepted `reduction_mode` values are not visible in this file.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_loss_cross_entropy_soft_backward_f16_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
row_stride_target: i64,
reduction_mode: i32,
scale_scalar: f32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_loss_cross_entropy_soft_backward_f16_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `dinput` is taken as `*const`. Returns a status code per
/// the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_loss_cross_entropy_soft_backward_f16_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
row_stride_target: i64,
reduction_mode: i32,
scale_scalar: f32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// Soft-target CrossEntropy backward pass, bf16 variant.
///
/// Same argument list as
/// `baracuda_kernels_loss_cross_entropy_soft_backward_f16_run`: `input`,
/// `target` and `dy` are `*const` operands and `dinput` is the `*mut`
/// destination. `scale_scalar` is passed as `f32` regardless of the
/// element type.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_loss_cross_entropy_soft_backward_bf16_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
row_stride_target: i64,
reduction_mode: i32,
scale_scalar: f32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_loss_cross_entropy_soft_backward_bf16_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `dinput` is taken as `*const`. Returns a status code per
/// the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_loss_cross_entropy_soft_backward_bf16_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
row_stride_target: i64,
reduction_mode: i32,
scale_scalar: f32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
/// Soft-target CrossEntropy backward pass, f64 variant.
///
/// Same argument list as
/// `baracuda_kernels_loss_cross_entropy_soft_backward_f16_run`: `input`,
/// `target` and `dy` are `*const` operands and `dinput` is the `*mut`
/// destination.
///
/// NOTE(review): `scale_scalar` is declared `f32` even for this f64
/// variant — confirm this matches the C signature.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_loss_cross_entropy_soft_backward_f64_run(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
row_stride_target: i64,
reduction_mode: i32,
scale_scalar: f32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_loss_cross_entropy_soft_backward_f64_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `dinput` is taken as `*const`. Returns a status code per
/// the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_loss_cross_entropy_soft_backward_f64_can_implement(
n_rows: i64,
class_extent: i32,
row_stride_input: i64,
row_stride_target: i64,
reduction_mode: i32,
scale_scalar: f32,
input: *const c_void,
target: *const c_void,
dy: *const c_void,
dinput: *const c_void,
) -> i32;
}
// ============================================================================
// Normalization family (Category G) — RMSNorm + LayerNorm + BatchNorm
// + GroupNorm (FW + BW).
//
// **RMSNorm / LayerNorm** use a multi-axis bitmask scheme: `norm_axes_mask`
// is an int32 bitmask (bit `d` set ⇒ axis `d` is normalized). The mask
// must be a suffix of `[0, rank)` (axes contiguous from the right —
// PyTorch's `normalized_shape` convention; validated in `can_implement`
// on the Rust side). `norm_total_extent` is the product of all axes in
// the mask. Per-output-cell two-pass row-stat scheme.
//
// **BatchNorm / GroupNorm** pre-collapse the input to logical
// `[N, C, S]` (S = product of spatial dims) and use a three-stage
// scheme: stage-1 per-group stat reduction, stage-2 per-cell normalize,
// stage-3 per-channel affine grads. group_kind selects BN (0) or GN/IN
// (1).
//
// BW launchers internally fire multiple kernels but all reductions are
// done via warp shuffles + smem — fully deterministic, no atomic-adds.
// ============================================================================
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// RMSNorm forward pass, f32 variant.
/// `y = x / sqrt(mean(x², over norm_axes) + eps) * gamma`.
///
/// `norm_axes_mask` is a bitmask over input axes (suffix of `[0,
/// rank)`); `norm_total_extent` is the product of those axes'
/// extents. `gamma` may be null (treated as 1). `rms_out` shape
/// equals input shape with norm axes collapsed to 1; only the
/// slot at inner_lin == 0 within each row is written.
///
/// `shape`, `stride_x`, `stride_y` and `stride_rms` are pointers to
/// integer arrays, presumably `rank` entries each. NOTE(review): whether
/// they must be host or device addresses is not visible here — confirm.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_rms_norm_f32_run(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_rms: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
y: *mut c_void,
rms_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_rms_norm_f32_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `y` and `rms_out` are taken as `*const`. Returns a status
/// code per the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_rms_norm_f32_can_implement(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_rms: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
y: *const c_void,
rms_out: *const c_void,
) -> i32;
/// RMSNorm forward pass, f16 variant. f32 accumulator inside the kernel.
///
/// Same argument list as `baracuda_kernels_rms_norm_f32_run`: `x` and
/// `gamma` are `*const` operands; `y` and `rms_out` are the `*mut`
/// destinations. Returns a status code per the crate-level
/// "Status codes" list (`0` = success).
pub fn baracuda_kernels_rms_norm_f16_run(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_rms: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
y: *mut c_void,
rms_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_rms_norm_f16_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `y` and `rms_out` are taken as `*const`. Returns a status
/// code per the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_rms_norm_f16_can_implement(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_rms: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
y: *const c_void,
rms_out: *const c_void,
) -> i32;
/// RMSNorm forward pass, bf16 variant. f32 accumulator inside the kernel.
///
/// Same argument list as `baracuda_kernels_rms_norm_f32_run`: `x` and
/// `gamma` are `*const` operands; `y` and `rms_out` are the `*mut`
/// destinations. Returns a status code per the crate-level
/// "Status codes" list (`0` = success).
pub fn baracuda_kernels_rms_norm_bf16_run(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_rms: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
y: *mut c_void,
rms_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_rms_norm_bf16_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `y` and `rms_out` are taken as `*const`. Returns a status
/// code per the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_rms_norm_bf16_can_implement(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_rms: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
y: *const c_void,
rms_out: *const c_void,
) -> i32;
/// RMSNorm forward pass, f64 variant.
///
/// Same argument list as `baracuda_kernels_rms_norm_f32_run`: `x` and
/// `gamma` are `*const` operands; `y` and `rms_out` are the `*mut`
/// destinations.
///
/// NOTE(review): `eps` is declared `f32` even for this f64 variant —
/// confirm this matches the C signature.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_rms_norm_f64_run(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_rms: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
y: *mut c_void,
rms_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_rms_norm_f64_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `y` and `rms_out` are taken as `*const`. Returns a status
/// code per the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_rms_norm_f64_can_implement(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_rms: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
y: *const c_void,
rms_out: *const c_void,
) -> i32;
/// RMSNorm backward pass, f32 variant. Computes `dx` and (when
/// `dgamma != null`)
/// `dgamma[i] = Σ over outer cells dy[..., i] · (x[..., i] / rms[..., 0])`
/// where `i` ranges over the joint normalized region of length
/// `norm_total_extent`.
///
/// `dy`, `x`, `gamma` and `rms` are `*const` operands; `dx` and `dgamma`
/// are the `*mut` destinations. `shape` and the `stride_*` arguments are
/// pointers to integer arrays, presumably `rank` entries each — confirm
/// against the kernel source.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_rms_norm_backward_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_rms: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
rms: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_rms_norm_backward_f32_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `dx` and `dgamma` are taken as `*const`. Returns a status
/// code per the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_rms_norm_backward_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_rms: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
rms: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
) -> i32;
/// RMSNorm backward pass, f16 variant.
///
/// Same argument list as `baracuda_kernels_rms_norm_backward_f32_run`:
/// `dy`, `x`, `gamma` and `rms` are `*const` operands; `dx` and `dgamma`
/// are the `*mut` destinations. Returns a status code per the crate-level
/// "Status codes" list (`0` = success).
pub fn baracuda_kernels_rms_norm_backward_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_rms: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
rms: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_rms_norm_backward_f16_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `dx` and `dgamma` are taken as `*const`. Returns a status
/// code per the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_rms_norm_backward_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_rms: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
rms: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
) -> i32;
/// RMSNorm backward pass, bf16 variant.
///
/// Same argument list as `baracuda_kernels_rms_norm_backward_f32_run`:
/// `dy`, `x`, `gamma` and `rms` are `*const` operands; `dx` and `dgamma`
/// are the `*mut` destinations. Returns a status code per the crate-level
/// "Status codes" list (`0` = success).
pub fn baracuda_kernels_rms_norm_backward_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_rms: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
rms: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_rms_norm_backward_bf16_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `dx` and `dgamma` are taken as `*const`. Returns a status
/// code per the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_rms_norm_backward_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_rms: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
rms: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
) -> i32;
/// RMSNorm backward pass, f64 variant.
///
/// Same argument list as `baracuda_kernels_rms_norm_backward_f32_run`:
/// `dy`, `x`, `gamma` and `rms` are `*const` operands; `dx` and `dgamma`
/// are the `*mut` destinations. Returns a status code per the crate-level
/// "Status codes" list (`0` = success).
pub fn baracuda_kernels_rms_norm_backward_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_rms: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
rms: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_rms_norm_backward_f64_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `dx` and `dgamma` are taken as `*const`. Returns a status
/// code per the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_rms_norm_backward_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_rms: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
rms: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
) -> i32;
/// LayerNorm forward pass, f32 variant.
/// `y = (x - mean) / sqrt(var + eps) * gamma + beta`.
///
/// `gamma` / `beta` independently optional. Biased ("population")
/// variance. Save buffers `mean_out` / `inv_std_out` share
/// `stride_save`, each shape == input with norm axes collapsed to 1.
///
/// `x`, `gamma` and `beta` are `*const` operands; `y`, `mean_out` and
/// `inv_std_out` are the `*mut` destinations. `shape` and the `stride_*`
/// arguments are pointers to integer arrays, presumably `rank` entries
/// each — confirm against the kernel source.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_layer_norm_f32_run(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_save: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *mut c_void,
mean_out: *mut c_void,
inv_std_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_layer_norm_f32_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `y`, `mean_out` and `inv_std_out` are taken as `*const`.
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_layer_norm_f32_can_implement(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_save: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *const c_void,
mean_out: *const c_void,
inv_std_out: *const c_void,
) -> i32;
/// LayerNorm forward pass, f16 variant. f32 accumulator inside the kernel.
///
/// Same argument list as `baracuda_kernels_layer_norm_f32_run`: `x`,
/// `gamma` and `beta` are `*const` operands; `y`, `mean_out` and
/// `inv_std_out` are the `*mut` destinations. Returns a status code per
/// the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_layer_norm_f16_run(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_save: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *mut c_void,
mean_out: *mut c_void,
inv_std_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_layer_norm_f16_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `y`, `mean_out` and `inv_std_out` are taken as `*const`.
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_layer_norm_f16_can_implement(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_save: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *const c_void,
mean_out: *const c_void,
inv_std_out: *const c_void,
) -> i32;
/// LayerNorm forward pass, bf16 variant. f32 accumulator inside the kernel.
///
/// Same argument list as `baracuda_kernels_layer_norm_f32_run`: `x`,
/// `gamma` and `beta` are `*const` operands; `y`, `mean_out` and
/// `inv_std_out` are the `*mut` destinations. Returns a status code per
/// the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_layer_norm_bf16_run(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_save: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *mut c_void,
mean_out: *mut c_void,
inv_std_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_layer_norm_bf16_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `y`, `mean_out` and `inv_std_out` are taken as `*const`.
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_layer_norm_bf16_can_implement(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_save: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *const c_void,
mean_out: *const c_void,
inv_std_out: *const c_void,
) -> i32;
/// LayerNorm forward pass, f64 variant.
///
/// Same argument list as `baracuda_kernels_layer_norm_f32_run`: `x`,
/// `gamma` and `beta` are `*const` operands; `y`, `mean_out` and
/// `inv_std_out` are the `*mut` destinations.
///
/// NOTE(review): `eps` is declared `f32` even for this f64 variant —
/// confirm this matches the C signature.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_layer_norm_f64_run(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_save: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *mut c_void,
mean_out: *mut c_void,
inv_std_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_layer_norm_f64_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `y`, `mean_out` and `inv_std_out` are taken as `*const`.
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_layer_norm_f64_can_implement(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_save: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *const c_void,
mean_out: *const c_void,
inv_std_out: *const c_void,
) -> i32;
/// LayerNorm backward pass, f32 variant. Computes `dx` and (when
/// non-null) `dgamma` / `dbeta` reductions. Caller passes saved `mean` +
/// `inv_std` from FW.
///
/// `dy`, `x`, `gamma`, `mean_in` and `inv_std_in` are `*const` operands;
/// `dx`, `dgamma` and `dbeta` are the `*mut` destinations. `shape` and the
/// `stride_*` arguments are pointers to integer arrays, presumably `rank`
/// entries each — confirm against the kernel source.
///
/// Returns a status code per the crate-level "Status codes" list
/// (`0` = success).
pub fn baracuda_kernels_layer_norm_backward_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_save: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
mean_in: *const c_void,
inv_std_in: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
dbeta: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_layer_norm_backward_f32_run`.
///
/// Mirrors the `_run` argument list without `workspace`, `workspace_bytes`
/// and `stream`; `dx`, `dgamma` and `dbeta` are taken as `*const`. Returns
/// a status code per the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_layer_norm_backward_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_save: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
mean_in: *const c_void,
inv_std_in: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
dbeta: *const c_void,
) -> i32;
/// LayerNorm backward pass, f16 variant.
///
/// Same argument list as `baracuda_kernels_layer_norm_backward_f32_run`:
/// `dy`, `x`, `gamma`, `mean_in` and `inv_std_in` are `*const` operands;
/// `dx`, `dgamma` and `dbeta` are the `*mut` destinations. Returns a
/// status code per the crate-level "Status codes" list (`0` = success).
pub fn baracuda_kernels_layer_norm_backward_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_save: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
mean_in: *const c_void,
inv_std_in: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
dbeta: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `layer_norm_backward_f16`.
///
/// Takes the same arguments as
/// [`baracuda_kernels_layer_norm_backward_f16_run`] minus `workspace`,
/// `workspace_bytes`, and `stream`, with the output pointers declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_layer_norm_backward_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_save: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
mean_in: *const c_void,
inv_std_in: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
dbeta: *const c_void,
) -> i32;
/// LayerNorm backward pass, bf16.
///
/// The parameter list is identical to
/// [`baracuda_kernels_layer_norm_backward_f32_run`]; presumably the same
/// contract applies (the dtype of the saved statistics is not visible
/// from this declaration — confirm). Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_layer_norm_backward_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_save: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
mean_in: *const c_void,
inv_std_in: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
dbeta: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `layer_norm_backward_bf16`.
///
/// Takes the same arguments as
/// [`baracuda_kernels_layer_norm_backward_bf16_run`] minus `workspace`,
/// `workspace_bytes`, and `stream`, with the output pointers declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_layer_norm_backward_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_save: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
mean_in: *const c_void,
inv_std_in: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
dbeta: *const c_void,
) -> i32;
/// LayerNorm backward pass, f64.
///
/// The parameter list is identical to
/// [`baracuda_kernels_layer_norm_backward_f32_run`]; presumably the same
/// contract applies (the dtype of the saved statistics is not visible
/// from this declaration — confirm). Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_layer_norm_backward_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_save: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
mean_in: *const c_void,
inv_std_in: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
dbeta: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `layer_norm_backward_f64`.
///
/// Takes the same arguments as
/// [`baracuda_kernels_layer_norm_backward_f64_run`] minus `workspace`,
/// `workspace_bytes`, and `stream`, with the output pointers declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_layer_norm_backward_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_save: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
mean_in: *const c_void,
inv_std_in: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
dbeta: *const c_void,
) -> i32;
}
// ============================================================================
// BatchNorm + GroupNorm — Phase 5.1 Norm family completion.
//
// Caller pre-collapses input to logical [N, C, S] (S = product of
// spatial dims). Channel axis is axis 1 of the original tensor
// (PyTorch convention). All BN/GN kernels share the same launcher ABI;
// `group_kind` selects BN (0, one group per channel) vs GN/IN
// (1, num_groups caller-supplied; InstanceNorm = num_groups == c_extent).
// ============================================================================
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// BatchNorm forward pass, f32. Training mode: computes per-channel
/// `(mean, inv_std)` from the batch + spatial cells and writes them to
/// `saved_mean` / `saved_rstd` for the backward pass. `gamma` / `beta`
/// are optional (both supplied together per PyTorch convention).
///
/// **In-place aliasing contract (Phase 65d)**: aliasing `y` with
/// `x` (`x_ptr == y_ptr`) is safe. BN runs as two cooperating
/// kernels: stage-1 reads `x` and writes only `saved_mean` /
/// `saved_rstd` (never touches `y`); stage-2 reads each `x[i]`
/// into a register before writing `y[i]` in the same thread (a
/// strict single-read-then-single-write per cell). Both stages
/// are aliasing-safe regardless of dtype — **f64 is in-place safe
/// here**, unlike Phase 65b/c normalizers where f64 falls back to
/// a multi-pass-global path. `saved_mean` / `saved_rstd` must be
/// distinct from `x` / `y` (their layout is `[group_count]`, a
/// different shape entirely). The same contract covers the
/// matching `bf16` / `f16` / `f64` FFI siblings and the GroupNorm
/// family (`group_kind = 1`) including the InstanceNorm dispatch
/// (`num_groups = c_extent`).
///
/// # Parameters
/// - `n_extent` / `c_extent` / `spatial_extent`: the logical
///   `[N, C, S]` extents the caller pre-collapsed the input to.
/// - `group_kind`: `0` = BatchNorm (one group per channel), `1` =
///   GroupNorm / InstanceNorm (`num_groups` caller-supplied).
/// - `workspace` / `workspace_bytes` / `stream`: see the crate-level docs.
///
/// # Returns
/// A crate-level status code (`0` = success).
pub fn baracuda_kernels_batch_norm_f32_run(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
eps: f32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *mut c_void,
saved_mean: *mut c_void,
saved_rstd: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `batch_norm_f32`.
///
/// Takes the same arguments as [`baracuda_kernels_batch_norm_f32_run`]
/// minus `workspace`, `workspace_bytes`, and `stream`, with the output
/// pointers declared `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_batch_norm_f32_can_implement(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
eps: f32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
) -> i32;
/// BatchNorm forward pass, f16.
///
/// The parameter list is identical to
/// [`baracuda_kernels_batch_norm_f32_run`], whose in-place aliasing
/// contract explicitly covers this sibling. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_batch_norm_f16_run(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
eps: f32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *mut c_void,
saved_mean: *mut c_void,
saved_rstd: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `batch_norm_f16`.
///
/// Takes the same arguments as [`baracuda_kernels_batch_norm_f16_run`]
/// minus `workspace`, `workspace_bytes`, and `stream`, with the output
/// pointers declared `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_batch_norm_f16_can_implement(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
eps: f32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
) -> i32;
/// BatchNorm forward pass, bf16.
///
/// The parameter list is identical to
/// [`baracuda_kernels_batch_norm_f32_run`], whose in-place aliasing
/// contract explicitly covers this sibling. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_batch_norm_bf16_run(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
eps: f32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *mut c_void,
saved_mean: *mut c_void,
saved_rstd: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `batch_norm_bf16`.
///
/// Takes the same arguments as [`baracuda_kernels_batch_norm_bf16_run`]
/// minus `workspace`, `workspace_bytes`, and `stream`, with the output
/// pointers declared `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_batch_norm_bf16_can_implement(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
eps: f32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
) -> i32;
/// BatchNorm forward pass, f64.
///
/// The parameter list is identical to
/// [`baracuda_kernels_batch_norm_f32_run`], whose in-place aliasing
/// contract explicitly covers this sibling (f64 included). Returns a
/// crate-level status code (`0` = success).
///
/// NOTE(review): `eps` is declared `f32` even for this f64 variant —
/// confirm this matches the C launcher's signature.
pub fn baracuda_kernels_batch_norm_f64_run(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
eps: f32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *mut c_void,
saved_mean: *mut c_void,
saved_rstd: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `batch_norm_f64`.
///
/// Takes the same arguments as [`baracuda_kernels_batch_norm_f64_run`]
/// minus `workspace`, `workspace_bytes`, and `stream`, with the output
/// pointers declared `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_batch_norm_f64_can_implement(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
eps: f32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
) -> i32;
/// BatchNorm backward pass, f32. Three-stage: per-group sum_dxh /
/// sum_dxhxh, per-cell dx, per-channel dgamma / dbeta. Requires a
/// workspace of `2 * group_count * sizeof(float)` bytes for the stage-1
/// partial sums (group_count = c_extent for BN).
///
/// # Parameters
/// - `n_extent` / `c_extent` / `spatial_extent`: the logical
///   `[N, C, S]` extents the caller pre-collapsed the input to.
/// - `group_kind`: `0` = BatchNorm, `1` = GroupNorm / InstanceNorm.
/// - `saved_mean` / `saved_rstd`: statistics written by the forward
///   pass.
/// - `workspace` / `workspace_bytes` / `stream`: see the crate-level docs.
///
/// # Returns
/// A crate-level status code (`0` = success; `4` = workspace too small
/// or null when required).
pub fn baracuda_kernels_batch_norm_backward_f32_run(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
dbeta: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `batch_norm_backward_f32`.
///
/// Takes the same arguments as
/// [`baracuda_kernels_batch_norm_backward_f32_run`] minus `workspace`,
/// `workspace_bytes`, and `stream`, with the output pointers declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_batch_norm_backward_f32_can_implement(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
dbeta: *const c_void,
) -> i32;
/// BatchNorm backward pass, f16.
///
/// The parameter list is identical to
/// [`baracuda_kernels_batch_norm_backward_f32_run`]. Returns a
/// crate-level status code (`0` = success).
///
/// NOTE(review): the f32 variant documents a workspace of
/// `2 * group_count * sizeof(float)` bytes; the requirement for this
/// dtype is not stated here — confirm.
pub fn baracuda_kernels_batch_norm_backward_f16_run(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
dbeta: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `batch_norm_backward_f16`.
///
/// Takes the same arguments as
/// [`baracuda_kernels_batch_norm_backward_f16_run`] minus `workspace`,
/// `workspace_bytes`, and `stream`, with the output pointers declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_batch_norm_backward_f16_can_implement(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
dbeta: *const c_void,
) -> i32;
/// BatchNorm backward pass, bf16.
///
/// The parameter list is identical to
/// [`baracuda_kernels_batch_norm_backward_f32_run`]. Returns a
/// crate-level status code (`0` = success).
///
/// NOTE(review): the f32 variant documents a workspace of
/// `2 * group_count * sizeof(float)` bytes; the requirement for this
/// dtype is not stated here — confirm.
pub fn baracuda_kernels_batch_norm_backward_bf16_run(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
dbeta: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `batch_norm_backward_bf16`.
///
/// Takes the same arguments as
/// [`baracuda_kernels_batch_norm_backward_bf16_run`] minus `workspace`,
/// `workspace_bytes`, and `stream`, with the output pointers declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_batch_norm_backward_bf16_can_implement(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
dbeta: *const c_void,
) -> i32;
/// BatchNorm backward pass, f64.
///
/// The parameter list is identical to
/// [`baracuda_kernels_batch_norm_backward_f32_run`]. Returns a
/// crate-level status code (`0` = success).
///
/// NOTE(review): the f32 variant documents a workspace of
/// `2 * group_count * sizeof(float)` bytes; the requirement for this
/// dtype is not stated here — confirm.
pub fn baracuda_kernels_batch_norm_backward_f64_run(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
dbeta: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `batch_norm_backward_f64`.
///
/// Takes the same arguments as
/// [`baracuda_kernels_batch_norm_backward_f64_run`] minus `workspace`,
/// `workspace_bytes`, and `stream`, with the output pointers declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_batch_norm_backward_f64_can_implement(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
dbeta: *const c_void,
) -> i32;
/// GroupNorm forward pass, f32. Per `(sample, group)` mean / inv_std,
/// per-channel affine. `num_groups` must divide `c_extent`.
/// `group_kind = 1` selects the GN dispatch (also used by
/// InstanceNorm with `num_groups == c_extent`).
///
/// **In-place aliasing contract (Phase 65d)**: aliasing `y` with
/// `x` (`x_ptr == y_ptr`) is safe for **all dtypes including
/// f64**, by the same two-stage argument as the BN trailblazer
/// (see [`baracuda_kernels_batch_norm_f32_run`]). The InstanceNorm
/// dispatch (`num_groups == c_extent`) inherits this contract.
///
/// # Parameters
/// - `n_extent` / `c_extent` / `spatial_extent`: the logical
///   `[N, C, S]` extents the caller pre-collapsed the input to.
/// - `saved_mean` / `saved_rstd`: outputs holding the statistics kept
///   for the backward pass.
/// - `workspace` / `workspace_bytes` / `stream`: see the crate-level docs.
///
/// # Returns
/// A crate-level status code (`0` = success).
pub fn baracuda_kernels_group_norm_f32_run(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
eps: f32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *mut c_void,
saved_mean: *mut c_void,
saved_rstd: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `group_norm_f32`.
///
/// Takes the same arguments as [`baracuda_kernels_group_norm_f32_run`]
/// minus `workspace`, `workspace_bytes`, and `stream`, with the output
/// pointers declared `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_group_norm_f32_can_implement(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
eps: f32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
) -> i32;
/// GroupNorm forward pass, f16.
///
/// The parameter list is identical to
/// [`baracuda_kernels_group_norm_f32_run`], whose in-place aliasing
/// contract is stated to hold for all dtypes. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_group_norm_f16_run(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
eps: f32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *mut c_void,
saved_mean: *mut c_void,
saved_rstd: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `group_norm_f16`.
///
/// Takes the same arguments as [`baracuda_kernels_group_norm_f16_run`]
/// minus `workspace`, `workspace_bytes`, and `stream`, with the output
/// pointers declared `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_group_norm_f16_can_implement(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
eps: f32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
) -> i32;
/// GroupNorm forward pass, bf16.
///
/// The parameter list is identical to
/// [`baracuda_kernels_group_norm_f32_run`], whose in-place aliasing
/// contract is stated to hold for all dtypes. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_group_norm_bf16_run(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
eps: f32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *mut c_void,
saved_mean: *mut c_void,
saved_rstd: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `group_norm_bf16`.
///
/// Takes the same arguments as [`baracuda_kernels_group_norm_bf16_run`]
/// minus `workspace`, `workspace_bytes`, and `stream`, with the output
/// pointers declared `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_group_norm_bf16_can_implement(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
eps: f32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
) -> i32;
/// GroupNorm forward pass, f64.
///
/// The parameter list is identical to
/// [`baracuda_kernels_group_norm_f32_run`], whose in-place aliasing
/// contract is stated to hold for all dtypes including f64. Returns a
/// crate-level status code (`0` = success).
///
/// NOTE(review): `eps` is declared `f32` even for this f64 variant —
/// confirm this matches the C launcher's signature.
pub fn baracuda_kernels_group_norm_f64_run(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
eps: f32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *mut c_void,
saved_mean: *mut c_void,
saved_rstd: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `group_norm_f64`.
///
/// Takes the same arguments as [`baracuda_kernels_group_norm_f64_run`]
/// minus `workspace`, `workspace_bytes`, and `stream`, with the output
/// pointers declared `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_group_norm_f64_can_implement(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
eps: f32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
) -> i32;
/// GroupNorm backward pass, f32. Workspace size:
/// `2 * (n_extent * num_groups) * sizeof(float)` bytes for the stage-1
/// partial sums.
///
/// # Parameters
/// - `n_extent` / `c_extent` / `spatial_extent`: the logical
///   `[N, C, S]` extents the caller pre-collapsed the input to.
/// - `group_kind`: `0` = BatchNorm, `1` = GroupNorm / InstanceNorm.
/// - `saved_mean` / `saved_rstd`: statistics written by the forward
///   pass.
/// - `workspace` / `workspace_bytes` / `stream`: see the crate-level docs.
///
/// # Returns
/// A crate-level status code (`0` = success; `4` = workspace too small
/// or null when required).
pub fn baracuda_kernels_group_norm_backward_f32_run(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
dbeta: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `group_norm_backward_f32`.
///
/// Takes the same arguments as
/// [`baracuda_kernels_group_norm_backward_f32_run`] minus `workspace`,
/// `workspace_bytes`, and `stream`, with the output pointers declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_group_norm_backward_f32_can_implement(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
dbeta: *const c_void,
) -> i32;
/// GroupNorm backward pass, f16.
///
/// The parameter list is identical to
/// [`baracuda_kernels_group_norm_backward_f32_run`]. Returns a
/// crate-level status code (`0` = success).
///
/// NOTE(review): the f32 variant documents a workspace of
/// `2 * (n_extent * num_groups) * sizeof(float)` bytes; the requirement
/// for this dtype is not stated here — confirm.
pub fn baracuda_kernels_group_norm_backward_f16_run(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
dbeta: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `group_norm_backward_f16`.
///
/// Takes the same arguments as
/// [`baracuda_kernels_group_norm_backward_f16_run`] minus `workspace`,
/// `workspace_bytes`, and `stream`, with the output pointers declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_group_norm_backward_f16_can_implement(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
dbeta: *const c_void,
) -> i32;
/// GroupNorm backward pass, bf16.
///
/// The parameter list is identical to
/// [`baracuda_kernels_group_norm_backward_f32_run`]. Returns a
/// crate-level status code (`0` = success).
///
/// NOTE(review): the f32 variant documents a workspace of
/// `2 * (n_extent * num_groups) * sizeof(float)` bytes; the requirement
/// for this dtype is not stated here — confirm.
pub fn baracuda_kernels_group_norm_backward_bf16_run(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
dbeta: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `group_norm_backward_bf16`.
///
/// Takes the same arguments as
/// [`baracuda_kernels_group_norm_backward_bf16_run`] minus `workspace`,
/// `workspace_bytes`, and `stream`, with the output pointers declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_group_norm_backward_bf16_can_implement(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
dbeta: *const c_void,
) -> i32;
/// GroupNorm backward pass, f64.
///
/// The parameter list is identical to
/// [`baracuda_kernels_group_norm_backward_f32_run`]. Returns a
/// crate-level status code (`0` = success).
///
/// NOTE(review): the f32 variant documents a workspace of
/// `2 * (n_extent * num_groups) * sizeof(float)` bytes; the requirement
/// for this dtype is not stated here — confirm.
pub fn baracuda_kernels_group_norm_backward_f64_run(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
dbeta: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `group_norm_backward_f64`.
///
/// Takes the same arguments as
/// [`baracuda_kernels_group_norm_backward_f64_run`] minus `workspace`,
/// `workspace_bytes`, and `stream`, with the output pointers declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_group_norm_backward_f64_can_implement(
n_extent: i32,
c_extent: i32,
spatial_extent: i32,
num_groups: i32,
group_kind: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
saved_mean: *const c_void,
saved_rstd: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
dbeta: *const c_void,
) -> i32;
}
// ============================================================================
// Trace — sum of the diagonal of a 2-D square matrix. Dispatched
// through `TracePlan<T>` (not `ReducePlan`) because trace reduces both
// axes via the `row == col` (diagonal) constraint rather than a single
// reduce_axis.
// ============================================================================
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Trace of a 2-D square matrix, f32.
///
/// Computes `y[0] = Σ x[i * stride_row + i * stride_col]` for `i` in
/// `0..rows`. The output is a single scalar.
///
/// # Parameters
/// - `rows`: number of diagonal entries summed.
/// - `stride_row` / `stride_col`: strides used to index `x` as shown in
///   the formula above.
/// - `workspace` / `workspace_bytes` / `stream`: see the crate-level docs.
///
/// # Returns
/// A crate-level status code (`0` = success).
pub fn baracuda_kernels_trace_f32_run(
rows: i32,
stride_row: i64,
stride_col: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `trace_f32`.
///
/// Takes the same arguments as [`baracuda_kernels_trace_f32_run`] minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_trace_f32_can_implement(
rows: i32,
stride_row: i64,
stride_col: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Trace, f16 (f32-detour accumulator).
///
/// The parameter list is identical to
/// [`baracuda_kernels_trace_f32_run`]. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_trace_f16_run(
rows: i32,
stride_row: i64,
stride_col: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `trace_f16`.
///
/// Takes the same arguments as [`baracuda_kernels_trace_f16_run`] minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_trace_f16_can_implement(
rows: i32,
stride_row: i64,
stride_col: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Trace, bf16 (f32-detour accumulator).
///
/// The parameter list is identical to
/// [`baracuda_kernels_trace_f32_run`]. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_trace_bf16_run(
rows: i32,
stride_row: i64,
stride_col: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `trace_bf16`.
///
/// Takes the same arguments as [`baracuda_kernels_trace_bf16_run`] minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_trace_bf16_can_implement(
rows: i32,
stride_row: i64,
stride_col: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Trace, f64.
///
/// The parameter list is identical to
/// [`baracuda_kernels_trace_f32_run`]. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_trace_f64_run(
rows: i32,
stride_row: i64,
stride_col: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `trace_f64`.
///
/// Takes the same arguments as [`baracuda_kernels_trace_f64_run`] minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_trace_f64_can_implement(
rows: i32,
stride_row: i64,
stride_col: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ============================================================================
// Shape / layout — Flip (Category N)
// ============================================================================
//
// **Aliasing (Phase 64)**: Flip, Roll, and Permute are **NOT** in-place
// safe. Unlike the elementwise unary/binary/ternary trailblazers, these
// shape ops have each thread reading from one coordinate and writing
// to a DIFFERENT coordinate (the flipped / shifted / permuted target).
// Two distinct threads can touch the same memory cell — one as a read,
// another as a write — and there's no guarantee about scheduling order.
// Same-pointer aliasing with these symbols is silent data corruption.
//
// If a caller needs an in-place flip/roll/permute, they must materialize
// the result into a fresh buffer and copy back (or use a bespoke
// in-place algorithm with paired-swap synchronization, which baracuda
// does not provide).
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Flip (reverse along selected axes), f32. `flip_axes[d]` is
/// 1 = reverse axis d, 0 = no-op.
///
/// **NOT in-place safe** — see the family-level aliasing note
/// above. Thread `i` reads from the flipped source coordinate and
/// writes to its output coordinate; two threads concurrently touch
/// each cell (one read, one write).
///
/// # Parameters
/// - `numel`, `rank`, `shape`: problem extent (`shape` and `flip_axes`
///   presumably hold `rank` entries each — TODO confirm).
/// - `stride_x` / `stride_y`: stride arrays for the input and output.
/// - `workspace` / `workspace_bytes` / `stream`: see the crate-level docs.
///
/// # Returns
/// A crate-level status code (`0` = success).
pub fn baracuda_kernels_flip_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
flip_axes: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `flip_f32`.
///
/// Takes the same arguments as [`baracuda_kernels_flip_f32_run`] minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_flip_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
flip_axes: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Flip, f16. Pure element copy — no math.
///
/// The parameter list is identical to [`baracuda_kernels_flip_f32_run`].
/// **NOT in-place safe** — see the family-level aliasing note above
/// this `extern` block. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_flip_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
flip_axes: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `flip_f16`.
///
/// Takes the same arguments as [`baracuda_kernels_flip_f16_run`] minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_flip_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
flip_axes: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Flip, bf16. Pure element copy — no math.
///
/// The parameter list is identical to [`baracuda_kernels_flip_f32_run`].
/// **NOT in-place safe** — see the family-level aliasing note above
/// this `extern` block. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_flip_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
flip_axes: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `flip_bf16`.
///
/// Takes the same arguments as [`baracuda_kernels_flip_bf16_run`] minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_flip_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
flip_axes: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Flip, f64. Pure element copy — no math.
///
/// The parameter list is identical to [`baracuda_kernels_flip_f32_run`].
/// **NOT in-place safe** — see the family-level aliasing note above
/// this `extern` block. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_flip_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
flip_axes: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `flip_f64`.
///
/// Takes the same arguments as [`baracuda_kernels_flip_f64_run`] minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_flip_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
flip_axes: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ============================================================================
// Shape / layout — Roll (Category N)
// ============================================================================
//
// **Aliasing**: NOT in-place safe — same reason as Flip (see the
// flip-family aliasing note above). Roll shifts each cell to a
// different output coordinate; same-pointer aliasing produces data
// corruption.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Roll (cyclic shift along axes), f32. `shifts[d]` is the shift
/// amount on axis d (positive or negative, mod shape[d]).
///
/// **NOT in-place safe** — thread `i` reads at source coord and
/// writes at shifted dest coord; same-pointer aliasing is unsafe.
///
/// # Parameters
/// - `numel`, `rank`, `shape`: problem extent (`shape` and `shifts`
///   presumably hold `rank` entries each — TODO confirm).
/// - `stride_x` / `stride_y`: stride arrays for the input and output.
/// - `workspace` / `workspace_bytes` / `stream`: see the crate-level docs.
///
/// # Returns
/// A crate-level status code (`0` = success).
pub fn baracuda_kernels_roll_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
shifts: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `roll_f32`.
///
/// Takes the same arguments as [`baracuda_kernels_roll_f32_run`] minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_roll_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
shifts: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Roll, f16. Pure element copy — no math.
///
/// The parameter list is identical to [`baracuda_kernels_roll_f32_run`].
/// **NOT in-place safe** — see the aliasing note above this `extern`
/// block. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_roll_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
shifts: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `roll_f16`.
///
/// Takes the same arguments as [`baracuda_kernels_roll_f16_run`] minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_roll_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
shifts: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Roll, bf16. Pure element copy — no math.
///
/// The parameter list is identical to [`baracuda_kernels_roll_f32_run`].
/// **NOT in-place safe** — see the aliasing note above this `extern`
/// block. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_roll_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
shifts: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `roll_bf16`.
///
/// Takes the same arguments as [`baracuda_kernels_roll_bf16_run`] minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_roll_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
shifts: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Roll, f64. Pure element copy — no math.
///
/// The parameter list is identical to [`baracuda_kernels_roll_f32_run`].
/// **NOT in-place safe** — see the aliasing note above this `extern`
/// block. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_roll_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
shifts: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `roll_f64`.
///
/// Takes the same arguments as [`baracuda_kernels_roll_f64_run`] minus
/// `workspace`, `workspace_bytes`, and `stream`, with `y` declared
/// `*const`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_roll_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
shifts: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ============================================================================
// Shape / layout — Contiguize (Phase 13.2, strided→contiguous copy)
// ============================================================================
//
// Byte-width-templated, dtype-agnostic at the kernel level. One symbol
// per natural element size (1 / 2 / 4 / 8 / 16 bytes) covers every
// byte-aligned baracuda dtype (f16, bf16, f32, f64, F32Strict, i32,
// i64, Bool, S8, U8, Fp8E4M3, Fp8E5M2, Complex32, Complex64). A
// separate nibble symbol handles S4 / U4 with a documented innermost-
// stride constraint (returns status 3 for unsupported source layouts).
//
// `source_strides` are SIGNED int64 (Flip ops produce negatives,
// BroadcastTo produces zeros). `source_offset` is in ELEMENTS, not
// bytes. Three host-side fast paths are baked into the launchers:
//
// 1. Source already contiguous + zero offset → single
// `cudaMemcpyAsync(DeviceToDevice)` of the full numel × element-size.
// 2. Innermost stride == 1 → per-outer-coord contiguous run copy
// (halves divmod cost vs the generic path).
// 3. Generic → one thread per output element; linear-index → multi-
// index unravel + signed-stride dot to compute source offset.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Contiguize (strided → contiguous copy), 1-byte element (Bool, S8,
/// U8, Fp8E4M3, Fp8E5M2).
///
/// # Parameters
/// - `dest` / `source`: output and input buffers.
/// - `shape`, `rank`: problem extent (`shape` and `source_strides`
///   presumably hold `rank` entries each — TODO confirm).
/// - `source_strides`: signed `i64` strides (may be negative or zero).
/// - `source_offset`: offset into `source` in ELEMENTS, not bytes.
/// - `stream`: see the crate-level docs. This family takes no
///   workspace arguments.
///
/// # Returns
/// A crate-level status code (`0` = success).
pub fn baracuda_kernels_contiguize_b1_run(
dest: *mut c_void,
source: *const c_void,
shape: *const i32,
source_strides: *const i64,
source_offset: i64,
rank: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `contiguize_b1`.
///
/// Takes the same arguments as [`baracuda_kernels_contiguize_b1_run`]
/// minus `stream`, with `dest` declared `*const`. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_contiguize_b1_can_implement(
dest: *const c_void,
source: *const c_void,
shape: *const i32,
source_strides: *const i64,
source_offset: i64,
rank: i32,
) -> i32;
/// Contiguize (strided → contiguous copy), 2-byte element (f16, bf16).
///
/// `source_strides` are signed `i64` and `source_offset` is in
/// elements, not bytes (see the family note above this `extern`
/// block). Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_contiguize_b2_run(
dest: *mut c_void,
source: *const c_void,
shape: *const i32,
source_strides: *const i64,
source_offset: i64,
rank: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `contiguize_b2`.
///
/// Takes the same arguments as [`baracuda_kernels_contiguize_b2_run`]
/// minus `stream`, with `dest` declared `*const`. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_contiguize_b2_can_implement(
dest: *const c_void,
source: *const c_void,
shape: *const i32,
source_strides: *const i64,
source_offset: i64,
rank: i32,
) -> i32;
/// Contiguize (strided → contiguous copy), 4-byte element (f32,
/// F32Strict, i32).
///
/// `source_strides` are signed `i64` and `source_offset` is in
/// elements, not bytes (see the family note above this `extern`
/// block). Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_contiguize_b4_run(
dest: *mut c_void,
source: *const c_void,
shape: *const i32,
source_strides: *const i64,
source_offset: i64,
rank: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `contiguize_b4`.
///
/// Takes the same arguments as [`baracuda_kernels_contiguize_b4_run`]
/// minus `stream`, with `dest` declared `*const`. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_contiguize_b4_can_implement(
dest: *const c_void,
source: *const c_void,
shape: *const i32,
source_strides: *const i64,
source_offset: i64,
rank: i32,
) -> i32;
/// Contiguize (strided → contiguous copy), 8-byte element (f64, i64,
/// Complex32).
///
/// `source_strides` are signed `i64` and `source_offset` is in
/// elements, not bytes (see the family note above this `extern`
/// block). Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_contiguize_b8_run(
dest: *mut c_void,
source: *const c_void,
shape: *const i32,
source_strides: *const i64,
source_offset: i64,
rank: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `contiguize_b8`.
///
/// Takes the same arguments as [`baracuda_kernels_contiguize_b8_run`]
/// minus `stream`, with `dest` declared `*const`. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_contiguize_b8_can_implement(
dest: *const c_void,
source: *const c_void,
shape: *const i32,
source_strides: *const i64,
source_offset: i64,
rank: i32,
) -> i32;
/// Contiguize (strided → contiguous copy), 16-byte element (Complex64).
///
/// `source_strides` are signed `i64` and `source_offset` is in
/// elements, not bytes (see the family note above this `extern`
/// block). Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_contiguize_b16_run(
dest: *mut c_void,
source: *const c_void,
shape: *const i32,
source_strides: *const i64,
source_offset: i64,
rank: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `contiguize_b16`.
///
/// Takes the same arguments as [`baracuda_kernels_contiguize_b16_run`]
/// minus `stream`, with `dest` declared `*const`. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_contiguize_b16_can_implement(
dest: *const c_void,
source: *const c_void,
shape: *const i32,
source_strides: *const i64,
source_offset: i64,
rank: i32,
) -> i32;
/// Contiguize (strided → contiguous copy), nibble-packed (S4 / U4).
///
/// Returns status 3 (not supported) when the source's innermost stride
/// is not one of `{1, -1, 2}` — i.e. when the source layout breaks
/// nibble alignment. Otherwise returns a crate-level status code
/// (`0` = success).
///
/// The parameter list matches the byte-width `contiguize_b*_run`
/// symbols. NOTE(review): how `source_offset` and `source_strides`
/// count packed nibbles is not visible from this declaration — confirm
/// against the launcher.
pub fn baracuda_kernels_contiguize_nibble_run(
dest: *mut c_void,
source: *const c_void,
shape: *const i32,
source_strides: *const i64,
source_offset: i64,
rank: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `contiguize_nibble`.
///
/// Takes the same arguments as
/// [`baracuda_kernels_contiguize_nibble_run`] minus `stream`, with
/// `dest` declared `*const`. Returns a crate-level status code
/// (`0` = success; `3` = not supported).
pub fn baracuda_kernels_contiguize_nibble_can_implement(
dest: *const c_void,
source: *const c_void,
shape: *const i32,
source_strides: *const i64,
source_offset: i64,
rank: i32,
) -> i32;
}
// ============================================================================
// Shape / layout — Triu / Tril (Category N, triangular masks)
// ============================================================================
//
// `torch.triu(input, diagonal)` / `torch.tril(input, diagonal)` mask
// the last two dims of `input`: triu keeps `j >= i + diagonal`, tril
// keeps `j <= i + diagonal`. The batch prefix (anything before the
// last two dims) is masked independently. Output shape == input shape;
// the kernel zeros the off-side. One templated kernel body covers both
// ops via a Predicate functor; per-dtype symbols here.
//
// Differentiable: `d_input = triu(d_output, diagonal)` and `d_input =
// tril(d_output, diagonal)` — the backward plans dispatch back to
// these same symbols with `dy → input` and `dx → output`.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Shared contract for every symbol in this block:
// - `input` / `output` point at elements of the dtype named in the
//   symbol; the output shape equals the input shape.
// - `shape` is assumed to hold `rank` i32 extents (TODO confirm for
//   this contiguous path).
// - `diagonal` is the mask offset over the last two dims: triu keeps
//   `j >= i + diagonal`, tril keeps `j <= i + diagonal`; everything
//   else is zeroed.
// - `*_run` enqueues on `stream`; `*_can_implement` takes the same
//   arguments minus `stream`, with `output` as `*const`.
// - Every function returns an `i32` status code (crate-level
//   "Status codes" section; `0` is success).
// - None of the pointers are validated by the callee: data pointers
//   must be valid device addresses and `stream` a valid CUDA stream
//   for the calling thread's current context.
/// Triu, f16.
pub fn baracuda_kernels_triu_f16_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `triu_f16`.
pub fn baracuda_kernels_triu_f16_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
) -> i32;
/// Triu, bf16.
pub fn baracuda_kernels_triu_bf16_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `triu_bf16`.
pub fn baracuda_kernels_triu_bf16_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
) -> i32;
/// Triu, f32. This is the triu trailblazer — its aliasing contract
/// carries over to every other `triu_<dt>_run`, `triu_<dt>_strided_run`,
/// and the sibling `tril_*` family.
///
/// Returns the upper-triangular part of the input matrix, zeroing
/// the strict lower triangle below `diagonal`.
///
/// **Aliasing (Phase 64)**: aliasing `output` with `input` is safe
/// (apply the triangular mask in-place). The kernel body is
/// `output[k] = pred(i, j, diagonal) ? input[k] : 0` where each
/// thread `k` reads `input[k]` before writing `output[k]` at the
/// same linear index. No cross-thread dependencies. Callers
/// implementing `Op::TriuInplace` / `Op::TrilInplace` (e.g. for
/// causal-mask preparation in attention or LU decomposition
/// triangular extraction) can dispatch with `input_ptr ==
/// output_ptr` without a dedicated `_inplace_` variant.
///
/// # Safety
///
/// `input` and `output` must be valid device addresses covering the
/// tensor described by `shape` / `rank`, and `stream` must be a valid
/// CUDA stream for the calling thread's current context; the callee
/// checks none of this.
pub fn baracuda_kernels_triu_f32_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `triu_f32`.
pub fn baracuda_kernels_triu_f32_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
) -> i32;
/// Triu, f64.
pub fn baracuda_kernels_triu_f64_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `triu_f64`.
pub fn baracuda_kernels_triu_f64_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
) -> i32;
/// Triu, i32.
pub fn baracuda_kernels_triu_i32_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `triu_i32`.
pub fn baracuda_kernels_triu_i32_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
) -> i32;
/// Triu, i64.
pub fn baracuda_kernels_triu_i64_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `triu_i64`.
pub fn baracuda_kernels_triu_i64_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
) -> i32;
/// Triu, Bool (storage = u8).
pub fn baracuda_kernels_triu_bool_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `triu_bool`.
pub fn baracuda_kernels_triu_bool_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
) -> i32;
// ---- Tril family: same signatures, predicate `j <= i + diagonal` ----
/// Tril, f16.
pub fn baracuda_kernels_tril_f16_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `tril_f16`.
pub fn baracuda_kernels_tril_f16_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
) -> i32;
/// Tril, bf16.
pub fn baracuda_kernels_tril_bf16_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `tril_bf16`.
pub fn baracuda_kernels_tril_bf16_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
) -> i32;
/// Tril, f32.
pub fn baracuda_kernels_tril_f32_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `tril_f32`.
pub fn baracuda_kernels_tril_f32_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
) -> i32;
/// Tril, f64.
pub fn baracuda_kernels_tril_f64_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `tril_f64`.
pub fn baracuda_kernels_tril_f64_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
) -> i32;
/// Tril, i32.
pub fn baracuda_kernels_tril_i32_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `tril_i32`.
pub fn baracuda_kernels_tril_i32_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
) -> i32;
/// Tril, i64.
pub fn baracuda_kernels_tril_i64_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `tril_i64`.
pub fn baracuda_kernels_tril_i64_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
) -> i32;
/// Tril, Bool (storage = u8).
pub fn baracuda_kernels_tril_bool_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `tril_bool`.
pub fn baracuda_kernels_tril_bool_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
diagonal: i32,
) -> i32;
}
// ============================================================================
// Shape / layout — Triu / Tril strided siblings (Phase 14.3)
// ============================================================================
//
// Companion launchers to the Triu / Tril contig fast path above. The
// Rust dispatcher picks contig vs strided at launch time based on
// whether both input and output are canonical row-major contiguous.
// The strided kernel reads the input at signed-stride offsets and
// writes the output at signed-stride offsets; the mask predicate
// (`j >= i + diagonal` for triu; `j <= i + diagonal` for tril) is
// evaluated on the last-two-dim coords as in the contig kernel.
//
// ABI:
// input — source device pointer (T const*).
// output — dest device pointer (T*).
// shape — points to `[i32; rank]` on the host stack.
// rank — i32, number of valid axes in [2, 8].
// stride_x / stride_y — points to `[i64; rank]` on the host stack,
// the per-axis SIGNED element stride for input
// and output respectively. Stride 0 marks a
// broadcast axis. Stride may be negative.
// diagonal — i32 mask offset (same semantics as contig).
// stream — cudaStream_t cast to `*mut c_void`.
//
// Status codes mirror the GEMM family (see crate-level doc).
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Shared contract for every symbol in this block (per the ABI notes
// preceding it):
// - `shape` holds `rank` i32 extents and `stride_x` / `stride_y` hold
//   `rank` signed i64 element strides for `input` / `output`; all
//   three are host arrays. `rank` must lie in [2, 8].
// - A stride of 0 marks a broadcast axis; strides may be negative.
// - `diagonal` has the same meaning as in the contiguous launchers.
// - `*_run` enqueues on `stream`; `*_strided_can_implement` takes the
//   same arguments minus `stream`, with `output` as `*const`.
// - Every function returns an `i32` status code (crate-level
//   "Status codes" section; `0` is success).
// - Pointers are not validated by the callee; `input` / `output` must
//   be valid device addresses for every offset the strides can reach.
/// Triu strided, f16.
pub fn baracuda_kernels_triu_f16_strided_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `triu_f16_strided`.
pub fn baracuda_kernels_triu_f16_strided_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
) -> i32;
/// Triu strided, bf16.
pub fn baracuda_kernels_triu_bf16_strided_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `triu_bf16_strided`.
pub fn baracuda_kernels_triu_bf16_strided_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
) -> i32;
/// Triu strided, f32.
pub fn baracuda_kernels_triu_f32_strided_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `triu_f32_strided`.
pub fn baracuda_kernels_triu_f32_strided_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
) -> i32;
/// Triu strided, f64.
pub fn baracuda_kernels_triu_f64_strided_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `triu_f64_strided`.
pub fn baracuda_kernels_triu_f64_strided_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
) -> i32;
/// Triu strided, i32.
pub fn baracuda_kernels_triu_i32_strided_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `triu_i32_strided`.
pub fn baracuda_kernels_triu_i32_strided_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
) -> i32;
/// Triu strided, i64.
pub fn baracuda_kernels_triu_i64_strided_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `triu_i64_strided`.
pub fn baracuda_kernels_triu_i64_strided_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
) -> i32;
/// Triu strided, Bool (storage = u8).
pub fn baracuda_kernels_triu_bool_strided_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `triu_bool_strided`.
pub fn baracuda_kernels_triu_bool_strided_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
) -> i32;
// ---- Tril strided family: same signatures, predicate `j <= i + diagonal` ----
/// Tril strided, f16.
pub fn baracuda_kernels_tril_f16_strided_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `tril_f16_strided`.
pub fn baracuda_kernels_tril_f16_strided_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
) -> i32;
/// Tril strided, bf16.
pub fn baracuda_kernels_tril_bf16_strided_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `tril_bf16_strided`.
pub fn baracuda_kernels_tril_bf16_strided_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
) -> i32;
/// Tril strided, f32.
pub fn baracuda_kernels_tril_f32_strided_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `tril_f32_strided`.
pub fn baracuda_kernels_tril_f32_strided_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
) -> i32;
/// Tril strided, f64.
pub fn baracuda_kernels_tril_f64_strided_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `tril_f64_strided`.
pub fn baracuda_kernels_tril_f64_strided_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
) -> i32;
/// Tril strided, i32.
pub fn baracuda_kernels_tril_i32_strided_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `tril_i32_strided`.
pub fn baracuda_kernels_tril_i32_strided_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
) -> i32;
/// Tril strided, i64.
pub fn baracuda_kernels_tril_i64_strided_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `tril_i64_strided`.
pub fn baracuda_kernels_tril_i64_strided_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
) -> i32;
/// Tril strided, Bool (storage = u8).
pub fn baracuda_kernels_tril_bool_strided_run(
input: *const c_void,
output: *mut c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `tril_bool_strided`.
pub fn baracuda_kernels_tril_bool_strided_can_implement(
input: *const c_void,
output: *const c_void,
shape: *const i32,
rank: i32,
stride_x: *const i64,
stride_y: *const i64,
diagonal: i32,
) -> i32;
}
// ============================================================================
// Shape / layout — Repeat (Category N, per-axis tile)
// ============================================================================
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Shared contract for every symbol in this block:
// - `x` is the read-only input, `y` the output; elements have the
//   dtype named in the symbol.
// - `input_shape` / `output_shape` are assumed to hold `rank` i32
//   extents, and `stride_x` / `stride_y` `rank` i64 strides for `x`
//   / `y` respectively (TODO confirm units and host/device residency).
// - `output_numel` is presumably the total output element count —
//   TODO confirm.
// - `workspace`, when non-null, must point to at least
//   `workspace_bytes` of writable device memory (crate-level docs).
// - `*_run` enqueues on `stream`; `*_can_implement` drops
//   `workspace`, `workspace_bytes`, and `stream`, and takes `y` as
//   `*const`.
// - Every function returns an `i32` status code (crate-level
//   "Status codes" section; `0` is success).
/// Repeat (per-axis tile), f32. `output.shape[d] =
/// input.shape[d] * repeats[d]`. Kernel computes
/// `input_coord[d] = output_coord[d] % input.shape[d]`.
///
/// # Safety
///
/// The callee validates no pointer: `x` / `y` must be valid device
/// addresses and `stream` a valid CUDA stream for the calling
/// thread's current context.
pub fn baracuda_kernels_repeat_f32_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `repeat_f32`.
pub fn baracuda_kernels_repeat_f32_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Repeat (per-axis tile), f16. Same parameter shape as the f32
/// variant — pure copy, no arithmetic.
pub fn baracuda_kernels_repeat_f16_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `repeat_f16`.
pub fn baracuda_kernels_repeat_f16_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Repeat (per-axis tile), bf16. Signature matches the f32 variant.
pub fn baracuda_kernels_repeat_bf16_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `repeat_bf16`.
pub fn baracuda_kernels_repeat_bf16_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Repeat (per-axis tile), f64. Signature matches the f32 variant.
pub fn baracuda_kernels_repeat_f64_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `repeat_f64`.
pub fn baracuda_kernels_repeat_f64_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ============================================================================
// Reductions — variance / std-dev (Phase 4 — Welford one-pass)
// ============================================================================
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Shared contract for every symbol in this block:
// - `x` is the read-only input, `y` the reduced output; elements have
//   the dtype named in the symbol.
// - `output_shape`, `stride_x`, `stride_y` are assumed to hold `rank`
//   entries each (TODO confirm units and host/device residency).
// - `reduce_axis` / `reduce_extent` / `reduce_stride_x` presumably
//   give the reduced axis, its length, and the input stride along it
//   — TODO confirm against the kernel.
// - `correction`: 1 for Bessel-corrected sample variance, 0 for
//   population variance.
// - `workspace`, when non-null, must point to at least
//   `workspace_bytes` of writable device memory (crate-level docs).
// - `*_run` enqueues on `stream`; `*_can_implement` drops
//   `workspace`, `workspace_bytes`, and `stream`, and takes `y` as
//   `*const`.
// - Every function returns an `i32` status code (crate-level
//   "Status codes" section; `0` is success).
/// Variance reduction along one axis, f32, Welford one-pass.
/// `correction = 1` for Bessel-corrected sample variance, 0 for
/// population variance.
///
/// # Safety
///
/// The callee validates no pointer: `x` / `y` must be valid device
/// addresses and `stream` a valid CUDA stream for the calling
/// thread's current context.
pub fn baracuda_kernels_reduce_var_f32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_var_f32`.
pub fn baracuda_kernels_reduce_var_f32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Std-dev along one axis, f32, Welford + sqrt.
pub fn baracuda_kernels_reduce_std_f32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_std_f32`.
pub fn baracuda_kernels_reduce_std_f32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
// ---- Var / Std FW dtype fanout (Phase 4 deferral 4.2 close-out) ----
// Welford state runs at the `WelfordAcc<T>` precision: f32 for
// f16/bf16/f32 inputs (the f16/bf16 detour through f32 at load /
// store time), f64 for f64 inputs. ABI identical to the f32 variant.
/// Variance reduction along one axis, f16.
pub fn baracuda_kernels_reduce_var_f16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_var_f16`.
pub fn baracuda_kernels_reduce_var_f16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Std-dev along one axis, f16.
pub fn baracuda_kernels_reduce_std_f16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_std_f16`.
pub fn baracuda_kernels_reduce_std_f16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Variance reduction along one axis, bf16.
pub fn baracuda_kernels_reduce_var_bf16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_var_bf16`.
pub fn baracuda_kernels_reduce_var_bf16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Std-dev along one axis, bf16.
pub fn baracuda_kernels_reduce_std_bf16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_std_bf16`.
pub fn baracuda_kernels_reduce_std_bf16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Variance reduction along one axis, f64 (Welford in f64).
pub fn baracuda_kernels_reduce_var_f64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_var_f64`.
pub fn baracuda_kernels_reduce_var_f64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Std-dev along one axis, f64 (Welford in f64 + sqrt).
pub fn baracuda_kernels_reduce_std_f64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_std_f64`.
pub fn baracuda_kernels_reduce_std_f64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ============================================================================
// Reductions — argmax / argmin (Phase 4 — i64 index output)
// ============================================================================
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// `argmax(x, axis=k)`, f32 input, i64 output. Ties broken by
/// first occurrence (smallest index wins).
///
/// `x` is the read-only input and `y` receives the indices. The
/// `reduce_*` arguments presumably describe the reduced axis (index,
/// length, and input stride along it) — TODO confirm against the
/// kernel. Returns an `i32` status code (crate-level "Status codes").
///
/// # Safety
///
/// Pointers are not validated by the callee: `x` / `y` must be valid
/// device addresses, `workspace` (when non-null) must cover
/// `workspace_bytes` of writable device memory, and `stream` must be
/// a valid CUDA stream for the calling thread's current context.
pub fn baracuda_kernels_arg_reduce_argmax_f32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_f32`.
///
/// Mirrors the `_run` signature without `workspace`,
/// `workspace_bytes`, and `stream`; `y` is taken as `*const`.
/// Returns an `i32` status code (crate-level "Status codes").
pub fn baracuda_kernels_arg_reduce_argmax_f32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, f32 input, i64 output.
///
/// `x` is the read-only input and `y` receives the indices. The
/// `reduce_*` arguments presumably describe the reduced axis (index,
/// length, and input stride along it) — TODO confirm against the
/// kernel. Returns an `i32` status code (crate-level "Status codes").
///
/// # Safety
///
/// Pointers are not validated by the callee: `x` / `y` must be valid
/// device addresses, `workspace` (when non-null) must cover
/// `workspace_bytes` of writable device memory, and `stream` must be
/// a valid CUDA stream for the calling thread's current context.
pub fn baracuda_kernels_arg_reduce_argmin_f32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_f32`.
///
/// Mirrors the `_run` signature without `workspace`,
/// `workspace_bytes`, and `stream`; `y` is taken as `*const`.
/// Returns an `i32` status code (crate-level "Status codes").
pub fn baracuda_kernels_arg_reduce_argmin_f32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, f16 input, i64 output.
///
/// `x` is the read-only input and `y` receives the indices. The
/// `reduce_*` arguments presumably describe the reduced axis (index,
/// length, and input stride along it) — TODO confirm against the
/// kernel. Returns an `i32` status code (crate-level "Status codes").
///
/// # Safety
///
/// Pointers are not validated by the callee: `x` / `y` must be valid
/// device addresses, `workspace` (when non-null) must cover
/// `workspace_bytes` of writable device memory, and `stream` must be
/// a valid CUDA stream for the calling thread's current context.
pub fn baracuda_kernels_arg_reduce_argmax_f16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_f16`.
///
/// Mirrors the `_run` signature without `workspace`,
/// `workspace_bytes`, and `stream`; `y` is taken as `*const`.
/// Returns an `i32` status code (crate-level "Status codes").
pub fn baracuda_kernels_arg_reduce_argmax_f16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, f16 input, i64 output.
///
/// `x` is the read-only input and `y` receives the indices. The
/// `reduce_*` arguments presumably describe the reduced axis (index,
/// length, and input stride along it) — TODO confirm against the
/// kernel. Returns an `i32` status code (crate-level "Status codes").
///
/// # Safety
///
/// Pointers are not validated by the callee: `x` / `y` must be valid
/// device addresses, `workspace` (when non-null) must cover
/// `workspace_bytes` of writable device memory, and `stream` must be
/// a valid CUDA stream for the calling thread's current context.
pub fn baracuda_kernels_arg_reduce_argmin_f16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_f16`.
///
/// Mirrors the `_run` signature without `workspace`,
/// `workspace_bytes`, and `stream`; `y` is taken as `*const`.
/// Returns an `i32` status code (crate-level "Status codes").
pub fn baracuda_kernels_arg_reduce_argmin_f16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, bf16 input, i64 output.
///
/// `x` is the read-only input and `y` receives the indices. The
/// `reduce_*` arguments presumably describe the reduced axis (index,
/// length, and input stride along it) — TODO confirm against the
/// kernel. Returns an `i32` status code (crate-level "Status codes").
///
/// # Safety
///
/// Pointers are not validated by the callee: `x` / `y` must be valid
/// device addresses, `workspace` (when non-null) must cover
/// `workspace_bytes` of writable device memory, and `stream` must be
/// a valid CUDA stream for the calling thread's current context.
pub fn baracuda_kernels_arg_reduce_argmax_bf16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_bf16`.
///
/// Mirrors the `_run` signature without `workspace`,
/// `workspace_bytes`, and `stream`; `y` is taken as `*const`.
/// Returns an `i32` status code (crate-level "Status codes").
pub fn baracuda_kernels_arg_reduce_argmax_bf16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, bf16 input, i64 output.
///
/// `x` is the read-only input and `y` receives the indices. The
/// `reduce_*` arguments presumably describe the reduced axis (index,
/// length, and input stride along it) — TODO confirm against the
/// kernel. Returns an `i32` status code (crate-level "Status codes").
///
/// # Safety
///
/// Pointers are not validated by the callee: `x` / `y` must be valid
/// device addresses, `workspace` (when non-null) must cover
/// `workspace_bytes` of writable device memory, and `stream` must be
/// a valid CUDA stream for the calling thread's current context.
pub fn baracuda_kernels_arg_reduce_argmin_bf16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_bf16`.
///
/// Mirrors the `_run` signature without `workspace`,
/// `workspace_bytes`, and `stream`; `y` is taken as `*const`.
/// Returns an `i32` status code (crate-level "Status codes").
pub fn baracuda_kernels_arg_reduce_argmin_bf16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, f64 input, i64 output.
///
/// `x` is the read-only input and `y` receives the indices. The
/// `reduce_*` arguments presumably describe the reduced axis (index,
/// length, and input stride along it) — TODO confirm against the
/// kernel. Returns an `i32` status code (crate-level "Status codes").
///
/// # Safety
///
/// Pointers are not validated by the callee: `x` / `y` must be valid
/// device addresses, `workspace` (when non-null) must cover
/// `workspace_bytes` of writable device memory, and `stream` must be
/// a valid CUDA stream for the calling thread's current context.
pub fn baracuda_kernels_arg_reduce_argmax_f64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_f64`.
///
/// Mirrors the `_run` signature without `workspace`,
/// `workspace_bytes`, and `stream`; `y` is taken as `*const`.
/// Returns an `i32` status code (crate-level "Status codes").
pub fn baracuda_kernels_arg_reduce_argmax_f64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, f64 input, i64 output.
///
/// `x` is the read-only input and `y` receives the indices. The
/// `reduce_*` arguments presumably describe the reduced axis (index,
/// length, and input stride along it) — TODO confirm against the
/// kernel. Returns an `i32` status code (crate-level "Status codes").
///
/// # Safety
///
/// Pointers are not validated by the callee: `x` / `y` must be valid
/// device addresses, `workspace` (when non-null) must cover
/// `workspace_bytes` of writable device memory, and `stream` must be
/// a valid CUDA stream for the calling thread's current context.
pub fn baracuda_kernels_arg_reduce_argmin_f64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_f64`.
///
/// Mirrors the `_run` signature without `workspace`,
/// `workspace_bytes`, and `stream`; `y` is taken as `*const`.
/// Returns an `i32` status code (crate-level "Status codes").
pub fn baracuda_kernels_arg_reduce_argmin_f64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// -------------------------------------------------------------------------
// Phase 12.2 — u32 output variants.
// -------------------------------------------------------------------------
/// `argmax(x, axis=k)`, f32 input, u32 output.
///
/// Same parameter list as the i64-output `argmax_f32` launcher; only
/// the index type stored through `y` differs (u32). Returns an `i32`
/// status code (crate-level "Status codes").
///
/// # Safety
///
/// Pointers are not validated by the callee: `x` / `y` must be valid
/// device addresses, `workspace` (when non-null) must cover
/// `workspace_bytes` of writable device memory, and `stream` must be
/// a valid CUDA stream for the calling thread's current context.
pub fn baracuda_kernels_arg_reduce_argmax_f32_u32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_f32_u32`.
///
/// Mirrors the `_run` signature without `workspace`,
/// `workspace_bytes`, and `stream`; `y` is taken as `*const`.
/// Returns an `i32` status code (crate-level "Status codes").
pub fn baracuda_kernels_arg_reduce_argmax_f32_u32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, f32 input, u32 output.
///
/// Same parameter list as the i64-output `argmin_f32` launcher; only
/// the index type stored through `y` differs (u32). Returns an `i32`
/// status code (crate-level "Status codes").
///
/// # Safety
///
/// Pointers are not validated by the callee: `x` / `y` must be valid
/// device addresses, `workspace` (when non-null) must cover
/// `workspace_bytes` of writable device memory, and `stream` must be
/// a valid CUDA stream for the calling thread's current context.
pub fn baracuda_kernels_arg_reduce_argmin_f32_u32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_f32_u32`.
///
/// Mirrors the `_run` signature without `workspace`,
/// `workspace_bytes`, and `stream`; `y` is taken as `*const`.
/// Returns an `i32` status code (crate-level "Status codes").
pub fn baracuda_kernels_arg_reduce_argmin_f32_u32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, f16 input, u32 output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_f16_u32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_f16_u32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_f16_u32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, f16 input, u32 output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_f16_u32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_f16_u32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_f16_u32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, bf16 input, u32 output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_bf16_u32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_bf16_u32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_bf16_u32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, bf16 input, u32 output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_bf16_u32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_bf16_u32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_bf16_u32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, f64 input, u32 output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_f64_u32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_f64_u32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_f64_u32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, f64 input, u32 output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_f64_u32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_f64_u32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_f64_u32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// -------------------------------------------------------------------------
// Phase 12.2 — i32 output variants.
// -------------------------------------------------------------------------
/// `argmax(x, axis=k)`, f32 input, i32 output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_f32_i32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_f32_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_f32_i32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, f32 input, i32 output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_f32_i32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_f32_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_f32_i32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, f16 input, i32 output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_f16_i32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_f16_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_f16_i32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, f16 input, i32 output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_f16_i32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_f16_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_f16_i32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, bf16 input, i32 output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_bf16_i32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_bf16_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_bf16_i32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, bf16 input, i32 output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_bf16_i32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_bf16_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_bf16_i32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, f64 input, i32 output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_f64_i32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_f64_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_f64_i32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, f64 input, i32 output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_f64_i32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_f64_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_f64_i32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// ====================================================================
// Phase 37 Gap 1b — integer-dtype argmin / argmax.
//
// Coverage: 6 input dtypes (u8, i8, u32, i16, i32, i64) × 2 ops × 2 idx
// dtypes = 24 SKUs, each exposing a `_run` and a `_can_implement` symbol.
// Ties broken by FIRST occurrence (smallest index wins) — same as
// the FP family. Idx-dtype suffix is explicit on every symbol
// (`_i32` or `_i64`) — no implicit-i64-default carve-out for
// integer inputs.
// ====================================================================
// ---- i32 idx output ----
/// `argmax(x, axis=k)`, u8 input, i32 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_u8_i32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_u8_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_u8_i32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, u8 input, i32 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_u8_i32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_u8_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_u8_i32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, i8 input, i32 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_i8_i32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_i8_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_i8_i32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, i8 input, i32 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_i8_i32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_i8_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_i8_i32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, u32 input, i32 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_u32_i32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_u32_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_u32_i32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, u32 input, i32 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_u32_i32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_u32_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_u32_i32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, i16 input, i32 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_i16_i32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_i16_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_i16_i32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, i16 input, i32 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_i16_i32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_i16_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_i16_i32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, i32 input, i32 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_i32_i32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_i32_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_i32_i32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, i32 input, i32 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_i32_i32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_i32_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_i32_i32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, i64 input, i32 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_i64_i32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_i64_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_i64_i32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, i64 input, i32 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_i64_i32_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_i64_i32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_i64_i32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
// ---- i64 idx output ----
/// `argmax(x, axis=k)`, u8 input, i64 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_u8_i64_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_u8_i64`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_u8_i64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, u8 input, i64 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_u8_i64_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_u8_i64`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_u8_i64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, i8 input, i64 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_i8_i64_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_i8_i64`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_i8_i64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, i8 input, i64 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_i8_i64_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_i8_i64`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_i8_i64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, u32 input, i64 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_u32_i64_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_u32_i64`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_u32_i64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, u32 input, i64 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_u32_i64_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_u32_i64`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_u32_i64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, i16 input, i64 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_i16_i64_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_i16_i64`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_i16_i64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, i16 input, i64 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_i16_i64_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_i16_i64`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_i16_i64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, i32 input, i64 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_i32_i64_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_i32_i64`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_i32_i64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, i32 input, i64 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_i32_i64_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_i32_i64`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_i32_i64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmax(x, axis=k)`, i64 input, i64 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmax_i64_i64_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmax_i64_i64`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmax_i64_i64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// `argmin(x, axis=k)`, i64 input, i64 idx output.
///
/// Ties resolve to the first occurrence (smallest index). Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_arg_reduce_argmin_i64_i64_run(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `arg_reduce_argmin_i64_i64`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_arg_reduce_argmin_i64_i64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
stride_x: *const i64, stride_y: *const i64,
reduce_axis: i32, reduce_extent: i32, reduce_stride_x: i64,
x: *const c_void, y: *const c_void,
) -> i32;
}
// ============================================================================
// Reductions — any / all / count_nonzero (Phase 4 deferral 4.4 — heterogeneous
// output dtype: Any / All → uint8_t Bool output; CountNonzero → int64_t output)
// ============================================================================
//
// Parameter shape mirrors the simple-reduce family (same ABI as
// `baracuda_kernels_reduce_sum_f32_run`); only the output dtype is
// fixed by the symbol. Wired matrix per op:
// {Any, All, CountNonzero} × {f32, f16, bf16, f64, i32, i64, Bool}
// = 21 SKUs total.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// `any(x, axis=k)` with f32 input, uint8_t Bool output.
///
/// Parameter list follows the simple-reduce ABI
/// (`baracuda_kernels_reduce_sum_f32_run`). Returns the crate-level
/// `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_any_f32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_any_f32`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_reduce_any_f32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `any(x, axis=k)` with f16 input, uint8_t Bool output.
///
/// Parameter list follows the simple-reduce ABI
/// (`baracuda_kernels_reduce_sum_f32_run`). Returns the crate-level
/// `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_any_f16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_any_f16`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_reduce_any_f16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `any(x, axis=k)` with bf16 input, uint8_t Bool output.
///
/// Parameter list follows the simple-reduce ABI
/// (`baracuda_kernels_reduce_sum_f32_run`). Returns the crate-level
/// `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_any_bf16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_any_bf16`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_reduce_any_bf16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `any(x, axis=k)` with f64 input, uint8_t Bool output.
///
/// Parameter list follows the simple-reduce ABI
/// (`baracuda_kernels_reduce_sum_f32_run`). Returns the crate-level
/// `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_any_f64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_any_f64`.
///
/// Takes the same problem description as the matching `_run` entry
/// point, minus `workspace`, `workspace_bytes` and `stream`. Returns the
/// crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_reduce_any_f64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `any(x, axis=k)` with i32 input, uint8_t Bool output.
///
/// Parameter list follows the simple-reduce ABI
/// (`baracuda_kernels_reduce_sum_f32_run`). Returns the crate-level
/// `i32` status code (`0` = success).
///
/// # Safety
///
/// Pointers are dereferenced without bounds checks. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable device
/// memory; `stream` must be a valid CUDA stream.
pub fn baracuda_kernels_reduce_any_i32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_any_i32`.
///
/// Same arguments as `baracuda_kernels_reduce_any_i32_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_reduce_any_i32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `any(x, axis=k)` with i64 input, uint8_t Bool output.
///
/// `x` is read and `y` is written. `workspace`, when non-null, must
/// point to at least `workspace_bytes` of writable device memory, and
/// `stream` is a `cudaStream_t` cast to `*mut c_void`. Returns a status
/// code as listed in the crate-level docs (`0` = success).
///
/// NOTE(review): `output_shape` / `stride_x` / `stride_y` are presumably
/// `rank`-element arrays, and `reduce_extent` / `reduce_stride_x` the
/// length and `x` stride of the reduced axis — confirm in the launcher.
pub fn baracuda_kernels_reduce_any_i64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_any_i64`.
///
/// Same arguments as `baracuda_kernels_reduce_any_i64_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_reduce_any_i64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `any(x, axis=k)` with Bool (uint8_t) input, uint8_t Bool output.
///
/// `x` is read and `y` is written. `workspace`, when non-null, must
/// point to at least `workspace_bytes` of writable device memory, and
/// `stream` is a `cudaStream_t` cast to `*mut c_void`. Returns a status
/// code as listed in the crate-level docs (`0` = success).
///
/// NOTE(review): `output_shape` / `stride_x` / `stride_y` are presumably
/// `rank`-element arrays, and `reduce_extent` / `reduce_stride_x` the
/// length and `x` stride of the reduced axis — confirm in the launcher.
pub fn baracuda_kernels_reduce_any_bool_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_any_bool`.
///
/// Same arguments as `baracuda_kernels_reduce_any_bool_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_reduce_any_bool_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `all(x, axis=k)` with f32 input, uint8_t Bool output.
///
/// `x` is read and `y` is written. `workspace`, when non-null, must
/// point to at least `workspace_bytes` of writable device memory, and
/// `stream` is a `cudaStream_t` cast to `*mut c_void`. Returns a status
/// code as listed in the crate-level docs (`0` = success).
///
/// NOTE(review): `output_shape` / `stride_x` / `stride_y` are presumably
/// `rank`-element arrays, and `reduce_extent` / `reduce_stride_x` the
/// length and `x` stride of the reduced axis — confirm in the launcher.
pub fn baracuda_kernels_reduce_all_f32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_all_f32`.
///
/// Same arguments as `baracuda_kernels_reduce_all_f32_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_reduce_all_f32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `all(x, axis=k)` with f16 input, uint8_t Bool output.
///
/// `x` is read and `y` is written. `workspace`, when non-null, must
/// point to at least `workspace_bytes` of writable device memory, and
/// `stream` is a `cudaStream_t` cast to `*mut c_void`. Returns a status
/// code as listed in the crate-level docs (`0` = success).
///
/// NOTE(review): `output_shape` / `stride_x` / `stride_y` are presumably
/// `rank`-element arrays, and `reduce_extent` / `reduce_stride_x` the
/// length and `x` stride of the reduced axis — confirm in the launcher.
pub fn baracuda_kernels_reduce_all_f16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_all_f16`.
///
/// Same arguments as `baracuda_kernels_reduce_all_f16_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_reduce_all_f16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `all(x, axis=k)` with bf16 input, uint8_t Bool output.
///
/// `x` is read and `y` is written. `workspace`, when non-null, must
/// point to at least `workspace_bytes` of writable device memory, and
/// `stream` is a `cudaStream_t` cast to `*mut c_void`. Returns a status
/// code as listed in the crate-level docs (`0` = success).
///
/// NOTE(review): `output_shape` / `stride_x` / `stride_y` are presumably
/// `rank`-element arrays, and `reduce_extent` / `reduce_stride_x` the
/// length and `x` stride of the reduced axis — confirm in the launcher.
pub fn baracuda_kernels_reduce_all_bf16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_all_bf16`.
///
/// Same arguments as `baracuda_kernels_reduce_all_bf16_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_reduce_all_bf16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `all(x, axis=k)` with f64 input, uint8_t Bool output.
///
/// `x` is read and `y` is written. `workspace`, when non-null, must
/// point to at least `workspace_bytes` of writable device memory, and
/// `stream` is a `cudaStream_t` cast to `*mut c_void`. Returns a status
/// code as listed in the crate-level docs (`0` = success).
///
/// NOTE(review): `output_shape` / `stride_x` / `stride_y` are presumably
/// `rank`-element arrays, and `reduce_extent` / `reduce_stride_x` the
/// length and `x` stride of the reduced axis — confirm in the launcher.
pub fn baracuda_kernels_reduce_all_f64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_all_f64`.
///
/// Same arguments as `baracuda_kernels_reduce_all_f64_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_reduce_all_f64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `all(x, axis=k)` with i32 input, uint8_t Bool output.
///
/// `x` is read and `y` is written. `workspace`, when non-null, must
/// point to at least `workspace_bytes` of writable device memory, and
/// `stream` is a `cudaStream_t` cast to `*mut c_void`. Returns a status
/// code as listed in the crate-level docs (`0` = success).
///
/// NOTE(review): `output_shape` / `stride_x` / `stride_y` are presumably
/// `rank`-element arrays, and `reduce_extent` / `reduce_stride_x` the
/// length and `x` stride of the reduced axis — confirm in the launcher.
pub fn baracuda_kernels_reduce_all_i32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_all_i32`.
///
/// Same arguments as `baracuda_kernels_reduce_all_i32_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_reduce_all_i32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `all(x, axis=k)` with i64 input, uint8_t Bool output.
///
/// `x` is read and `y` is written. `workspace`, when non-null, must
/// point to at least `workspace_bytes` of writable device memory, and
/// `stream` is a `cudaStream_t` cast to `*mut c_void`. Returns a status
/// code as listed in the crate-level docs (`0` = success).
///
/// NOTE(review): `output_shape` / `stride_x` / `stride_y` are presumably
/// `rank`-element arrays, and `reduce_extent` / `reduce_stride_x` the
/// length and `x` stride of the reduced axis — confirm in the launcher.
pub fn baracuda_kernels_reduce_all_i64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_all_i64`.
///
/// Same arguments as `baracuda_kernels_reduce_all_i64_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_reduce_all_i64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `all(x, axis=k)` with Bool (uint8_t) input, uint8_t Bool output.
///
/// `x` is read and `y` is written. `workspace`, when non-null, must
/// point to at least `workspace_bytes` of writable device memory, and
/// `stream` is a `cudaStream_t` cast to `*mut c_void`. Returns a status
/// code as listed in the crate-level docs (`0` = success).
///
/// NOTE(review): `output_shape` / `stride_x` / `stride_y` are presumably
/// `rank`-element arrays, and `reduce_extent` / `reduce_stride_x` the
/// length and `x` stride of the reduced axis — confirm in the launcher.
pub fn baracuda_kernels_reduce_all_bool_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_all_bool`.
///
/// Same arguments as `baracuda_kernels_reduce_all_bool_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_reduce_all_bool_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `count_nonzero(x, axis=k)` with f32 input, i64 output.
///
/// `x` is read and `y` is written. `workspace`, when non-null, must
/// point to at least `workspace_bytes` of writable device memory, and
/// `stream` is a `cudaStream_t` cast to `*mut c_void`. Returns a status
/// code as listed in the crate-level docs (`0` = success).
///
/// NOTE(review): `output_shape` / `stride_x` / `stride_y` are presumably
/// `rank`-element arrays, and `reduce_extent` / `reduce_stride_x` the
/// length and `x` stride of the reduced axis — confirm in the launcher.
pub fn baracuda_kernels_reduce_count_nonzero_f32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_count_nonzero_f32`.
///
/// Same arguments as `baracuda_kernels_reduce_count_nonzero_f32_run`
/// without `workspace`, `workspace_bytes`, and `stream`; `y` is
/// `*const` here. Returns a status code as listed in the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_reduce_count_nonzero_f32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `count_nonzero(x, axis=k)` with f16 input, i64 output.
///
/// `x` is read and `y` is written. `workspace`, when non-null, must
/// point to at least `workspace_bytes` of writable device memory, and
/// `stream` is a `cudaStream_t` cast to `*mut c_void`. Returns a status
/// code as listed in the crate-level docs (`0` = success).
///
/// NOTE(review): `output_shape` / `stride_x` / `stride_y` are presumably
/// `rank`-element arrays, and `reduce_extent` / `reduce_stride_x` the
/// length and `x` stride of the reduced axis — confirm in the launcher.
pub fn baracuda_kernels_reduce_count_nonzero_f16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_count_nonzero_f16`.
///
/// Same arguments as `baracuda_kernels_reduce_count_nonzero_f16_run`
/// without `workspace`, `workspace_bytes`, and `stream`; `y` is
/// `*const` here. Returns a status code as listed in the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_reduce_count_nonzero_f16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `count_nonzero(x, axis=k)` with bf16 input, i64 output.
///
/// `x` is read and `y` is written. `workspace`, when non-null, must
/// point to at least `workspace_bytes` of writable device memory, and
/// `stream` is a `cudaStream_t` cast to `*mut c_void`. Returns a status
/// code as listed in the crate-level docs (`0` = success).
///
/// NOTE(review): `output_shape` / `stride_x` / `stride_y` are presumably
/// `rank`-element arrays, and `reduce_extent` / `reduce_stride_x` the
/// length and `x` stride of the reduced axis — confirm in the launcher.
pub fn baracuda_kernels_reduce_count_nonzero_bf16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_count_nonzero_bf16`.
///
/// Same arguments as `baracuda_kernels_reduce_count_nonzero_bf16_run`
/// without `workspace`, `workspace_bytes`, and `stream`; `y` is
/// `*const` here. Returns a status code as listed in the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_reduce_count_nonzero_bf16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `count_nonzero(x, axis=k)` with f64 input, i64 output.
///
/// `x` is read and `y` is written. `workspace`, when non-null, must
/// point to at least `workspace_bytes` of writable device memory, and
/// `stream` is a `cudaStream_t` cast to `*mut c_void`. Returns a status
/// code as listed in the crate-level docs (`0` = success).
///
/// NOTE(review): `output_shape` / `stride_x` / `stride_y` are presumably
/// `rank`-element arrays, and `reduce_extent` / `reduce_stride_x` the
/// length and `x` stride of the reduced axis — confirm in the launcher.
pub fn baracuda_kernels_reduce_count_nonzero_f64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_count_nonzero_f64`.
///
/// Same arguments as `baracuda_kernels_reduce_count_nonzero_f64_run`
/// without `workspace`, `workspace_bytes`, and `stream`; `y` is
/// `*const` here. Returns a status code as listed in the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_reduce_count_nonzero_f64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `count_nonzero(x, axis=k)` with i32 input, i64 output.
///
/// `x` is read and `y` is written. `workspace`, when non-null, must
/// point to at least `workspace_bytes` of writable device memory, and
/// `stream` is a `cudaStream_t` cast to `*mut c_void`. Returns a status
/// code as listed in the crate-level docs (`0` = success).
///
/// NOTE(review): `output_shape` / `stride_x` / `stride_y` are presumably
/// `rank`-element arrays, and `reduce_extent` / `reduce_stride_x` the
/// length and `x` stride of the reduced axis — confirm in the launcher.
pub fn baracuda_kernels_reduce_count_nonzero_i32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_count_nonzero_i32`.
///
/// Same arguments as `baracuda_kernels_reduce_count_nonzero_i32_run`
/// without `workspace`, `workspace_bytes`, and `stream`; `y` is
/// `*const` here. Returns a status code as listed in the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_reduce_count_nonzero_i32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `count_nonzero(x, axis=k)` with i64 input, i64 output.
///
/// `x` is read and `y` is written. `workspace`, when non-null, must
/// point to at least `workspace_bytes` of writable device memory, and
/// `stream` is a `cudaStream_t` cast to `*mut c_void`. Returns a status
/// code as listed in the crate-level docs (`0` = success).
///
/// NOTE(review): `output_shape` / `stride_x` / `stride_y` are presumably
/// `rank`-element arrays, and `reduce_extent` / `reduce_stride_x` the
/// length and `x` stride of the reduced axis — confirm in the launcher.
pub fn baracuda_kernels_reduce_count_nonzero_i64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_count_nonzero_i64`.
///
/// Same arguments as `baracuda_kernels_reduce_count_nonzero_i64_run`
/// without `workspace`, `workspace_bytes`, and `stream`; `y` is
/// `*const` here. Returns a status code as listed in the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_reduce_count_nonzero_i64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `count_nonzero(x, axis=k)` with Bool (uint8_t) input, i64 output.
///
/// `x` is read and `y` is written. `workspace`, when non-null, must
/// point to at least `workspace_bytes` of writable device memory, and
/// `stream` is a `cudaStream_t` cast to `*mut c_void`. Returns a status
/// code as listed in the crate-level docs (`0` = success).
///
/// NOTE(review): `output_shape` / `stride_x` / `stride_y` are presumably
/// `rank`-element arrays, and `reduce_extent` / `reduce_stride_x` the
/// length and `x` stride of the reduced axis — confirm in the launcher.
pub fn baracuda_kernels_reduce_count_nonzero_bool_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_count_nonzero_bool`.
///
/// Same arguments as `baracuda_kernels_reduce_count_nonzero_bool_run`
/// without `workspace`, `workspace_bytes`, and `stream`; `y` is
/// `*const` here. Returns a status code as listed in the crate-level
/// docs (`0` = success).
pub fn baracuda_kernels_reduce_count_nonzero_bool_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ============================================================================
// Shape / layout — Permute (Category N)
// ============================================================================
//
// `y = x.permute(dims)` — output axis d is input axis `dims[d]`. The
// kernel walks input cells and writes to permuted output positions.
//
// **Aliasing**: NOT in-place safe — same reason as Flip / Roll.
// Permute remaps every cell to a different output coordinate; two
// threads can concurrently touch each cell (one as a read of the
// pre-permute source, another as a write of the post-permute dest).
// Same-pointer aliasing produces data corruption.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Materialized permute, f32.
///
/// Output axis `d` is input axis `dims[d]`. `x` is read and `y` is
/// written. `workspace`, when non-null, must point to at least
/// `workspace_bytes` of writable device memory, and `stream` is a
/// `cudaStream_t` cast to `*mut c_void`. Returns a status code as
/// listed in the crate-level docs (`0` = success).
///
/// NOTE(review): `input_shape` / `dims` / `stride_x` / `stride_y` are
/// presumably `rank`-element arrays — confirm in the launcher.
///
/// **NOT in-place safe** — see family-level aliasing note above.
pub fn baracuda_kernels_permute_f32_run(
input_numel: i64,
rank: i32,
input_shape: *const i32,
dims: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `permute_f32`.
///
/// Same arguments as `baracuda_kernels_permute_f32_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_permute_f32_can_implement(
input_numel: i64,
rank: i32,
input_shape: *const i32,
dims: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Materialized permute, f16. Pure element copy — no math.
///
/// Same parameter list as the f32 launcher. **NOT in-place safe**:
/// `x` and `y` must not alias. Returns a status code as listed in the
/// crate-level docs (`0` = success).
pub fn baracuda_kernels_permute_f16_run(
input_numel: i64,
rank: i32,
input_shape: *const i32,
dims: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `permute_f16`.
///
/// Same arguments as `baracuda_kernels_permute_f16_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_permute_f16_can_implement(
input_numel: i64,
rank: i32,
input_shape: *const i32,
dims: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Materialized permute, bf16. Pure element copy — no math.
///
/// Same parameter list as the f32 launcher. **NOT in-place safe**:
/// `x` and `y` must not alias. Returns a status code as listed in the
/// crate-level docs (`0` = success).
pub fn baracuda_kernels_permute_bf16_run(
input_numel: i64,
rank: i32,
input_shape: *const i32,
dims: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `permute_bf16`.
///
/// Same arguments as `baracuda_kernels_permute_bf16_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_permute_bf16_can_implement(
input_numel: i64,
rank: i32,
input_shape: *const i32,
dims: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Materialized permute, f64. Pure element copy — no math.
///
/// Same parameter list as the f32 launcher. **NOT in-place safe**:
/// `x` and `y` must not alias. Returns a status code as listed in the
/// crate-level docs (`0` = success).
pub fn baracuda_kernels_permute_f64_run(
input_numel: i64,
rank: i32,
input_shape: *const i32,
dims: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `permute_f64`.
///
/// Same arguments as `baracuda_kernels_permute_f64_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_permute_f64_can_implement(
input_numel: i64,
rank: i32,
input_shape: *const i32,
dims: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ============================================================================
// Shape / layout — Concat (2-input variant of Category N)
// ============================================================================
//
// `y = cat(a, b, dim=k)` with 2-input arity. Output shape per-axis
// matches a / b except `output[k] = a.shape[k] + b.shape[k]`. Variable-
// arity (N inputs) is a future plan shape (would need device-side
// packing of N pointers + N stride arrays through kernel param block).
// Launchers are declared below for f32, f16, bf16, and f64; the
// f16/bf16/f64 variants are single-INSTANTIATE fanout.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// `cat(a, b, dim)`, f32, contig output.
///
/// `output_shape` matches a / b shape except `[concat_dim]` =
/// `a.shape[concat_dim] + b.shape[concat_dim]`. `split_offset` is
/// `a.shape[concat_dim]` — the kernel uses it to branch between
/// reading from a or b.
///
/// `a` and `b` are read and `y` is written. `workspace`, when
/// non-null, must point to at least `workspace_bytes` of writable
/// device memory, and `stream` is a `cudaStream_t` cast to
/// `*mut c_void`. Returns a status code as listed in the crate-level
/// docs (`0` = success).
///
/// # Safety
/// All device pointers must remain valid for the launch. Host
/// arrays must remain valid for the host-side launch call.
pub fn baracuda_kernels_concat2_f32_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
concat_dim: i32,
split_offset: i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `concat2_f32`.
///
/// Same arguments as `baracuda_kernels_concat2_f32_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_concat2_f32_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
concat_dim: i32,
split_offset: i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// `cat(a, b, dim)`, f16, contig output. See f32 variant.
pub fn baracuda_kernels_concat2_f16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
concat_dim: i32,
split_offset: i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `concat2_f16`.
///
/// Same arguments as `baracuda_kernels_concat2_f16_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_concat2_f16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
concat_dim: i32,
split_offset: i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// `cat(a, b, dim)`, bf16, contig output. See f32 variant.
pub fn baracuda_kernels_concat2_bf16_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
concat_dim: i32,
split_offset: i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `concat2_bf16`.
///
/// Same arguments as `baracuda_kernels_concat2_bf16_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_concat2_bf16_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
concat_dim: i32,
split_offset: i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// `cat(a, b, dim)`, f64, contig output. See f32 variant.
pub fn baracuda_kernels_concat2_f64_run(
output_numel: i64,
rank: i32,
output_shape: *const i32,
concat_dim: i32,
split_offset: i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `concat2_f64`.
///
/// Same arguments as `baracuda_kernels_concat2_f64_run` without
/// `workspace`, `workspace_bytes`, and `stream`; `y` is `*const` here.
/// Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_concat2_f64_can_implement(
output_numel: i64,
rank: i32,
output_shape: *const i32,
concat_dim: i32,
split_offset: i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
}
// ============================================================================
// Shape / layout — Pad (Category N trailblazer)
// ============================================================================
//
// `y = pad(x, pad_low, pad_high, value)` over arbitrary-rank tensors.
// Output shape per-axis is `input[d] + pad_low[d] + pad_high[d]`.
// Constant mode is declared here for f32, f16, bf16, and f64; the
// Reflect, Replicate, and Circular modes are declared in the `extern`
// block that follows this one.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Pad with a constant value, f32, contig output.
///
/// `input_shape` / `output_shape` / `pad_low` point to host arrays
/// of at least `rank` elements (i32). `stride_x` / `stride_y` are
/// element-stride arrays (i64). The output is conventionally
/// contiguous but the FFI accepts any stride pattern.
///
/// There is no `pad_high` parameter (NOTE(review): presumably implied
/// by `output_shape` — confirm in the launcher). `value` is the fill
/// constant. Returns a status code as listed in the crate-level docs
/// (`0` = success).
///
/// # Safety
/// All device pointers must remain valid for the duration of the
/// launch. Host pointers must remain valid for the duration of the
/// host-side launch call (the launcher copies them into kernel
/// param-block structs).
pub fn baracuda_kernels_pad_constant_f32_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
value: f32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `pad_constant_f32`.
///
/// Same arguments as `baracuda_kernels_pad_constant_f32_run`
/// (including `value`) without `workspace`, `workspace_bytes`, and
/// `stream`; `y` is `*const` here. Returns a status code as listed in
/// the crate-level docs (`0` = success).
pub fn baracuda_kernels_pad_constant_f32_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
value: f32,
) -> i32;
/// Pad with a constant value, f16, contig output. The `value`
/// argument carries the `__half` bit pattern as `u16` — Rust callers
/// can produce it via `half::f16::to_bits()`. ABI-compatible because
/// `__half` is a 2-byte `__CUDA_ALIGN__(2)` POD struct passed in the
/// same register slot as `unsigned short`.
pub fn baracuda_kernels_pad_constant_f16_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
value: u16,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `pad_constant_f16`.
///
/// Same arguments as `baracuda_kernels_pad_constant_f16_run`
/// (including the `u16` bit-pattern `value`) without `workspace`,
/// `workspace_bytes`, and `stream`; `y` is `*const` here. Returns a
/// status code as listed in the crate-level docs (`0` = success).
pub fn baracuda_kernels_pad_constant_f16_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
value: u16,
) -> i32;
/// Pad with a constant value, bf16, contig output. The `value`
/// argument carries the `__nv_bfloat16` bit pattern as `u16` — Rust
/// callers can produce it via `half::bf16::to_bits()`.
pub fn baracuda_kernels_pad_constant_bf16_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
value: u16,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `pad_constant_bf16`.
///
/// Same arguments as `baracuda_kernels_pad_constant_bf16_run`
/// (including the `u16` bit-pattern `value`) without `workspace`,
/// `workspace_bytes`, and `stream`; `y` is `*const` here. Returns a
/// status code as listed in the crate-level docs (`0` = success).
pub fn baracuda_kernels_pad_constant_bf16_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
value: u16,
) -> i32;
/// Pad with a constant value, f64, contig output.
///
/// Same parameter list as the f32 launcher except that `value` is
/// `f64`. Returns a status code as listed in the crate-level docs
/// (`0` = success).
pub fn baracuda_kernels_pad_constant_f64_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
value: f64,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `pad_constant_f64`.
///
/// Same arguments as `baracuda_kernels_pad_constant_f64_run`
/// (including `value`) without `workspace`, `workspace_bytes`, and
/// `stream`; `y` is `*const` here. Returns a status code as listed in
/// the crate-level docs (`0` = success).
pub fn baracuda_kernels_pad_constant_f64_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
value: f64,
) -> i32;
}
// Pad — Reflect / Replicate / Circular modes. None of these take a
// `value` parameter; the pad-region values come from the input itself
// (mirror, clamp, or cyclic wrap respectively). Parameter shape is
// otherwise identical to the constant-mode launchers.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Pad reflect, f32. Mirror input across the boundary (no edge
/// duplication).
pub fn baracuda_kernels_pad_reflect_f32_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_pad_reflect_f32_can_implement` (baracuda kernels pad reflect f32 can implement).
pub fn baracuda_kernels_pad_reflect_f32_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Pad reflect, f16.
pub fn baracuda_kernels_pad_reflect_f16_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_pad_reflect_f16_can_implement` (baracuda kernels pad reflect f16 can implement).
pub fn baracuda_kernels_pad_reflect_f16_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Pad reflect, bf16.
pub fn baracuda_kernels_pad_reflect_bf16_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_pad_reflect_bf16_can_implement` (baracuda kernels pad reflect bf16 can implement).
pub fn baracuda_kernels_pad_reflect_bf16_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Pad reflect, f64.
pub fn baracuda_kernels_pad_reflect_f64_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_pad_reflect_f64_can_implement` (baracuda kernels pad reflect f64 can implement).
pub fn baracuda_kernels_pad_reflect_f64_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Pad replicate, f32. Clamp to the edge value of the input.
pub fn baracuda_kernels_pad_replicate_f32_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_pad_replicate_f32`. Host-side only.
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_pad_replicate_f32_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Pad replicate, f16.
pub fn baracuda_kernels_pad_replicate_f16_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_pad_replicate_f16`. Host-side only.
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_pad_replicate_f16_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Pad replicate, bf16.
pub fn baracuda_kernels_pad_replicate_bf16_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_pad_replicate_bf16`. Host-side only.
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_pad_replicate_bf16_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Pad replicate, f64.
pub fn baracuda_kernels_pad_replicate_f64_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_pad_replicate_f64`. Host-side only.
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_pad_replicate_f64_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Pad circular, f32. Cyclic wrap from the opposite end of each
/// axis.
pub fn baracuda_kernels_pad_circular_f32_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_pad_circular_f32_can_implement` (baracuda kernels pad circular f32 can implement).
pub fn baracuda_kernels_pad_circular_f32_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Pad circular, f16.
pub fn baracuda_kernels_pad_circular_f16_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_pad_circular_f16_can_implement` (baracuda kernels pad circular f16 can implement).
pub fn baracuda_kernels_pad_circular_f16_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Pad circular, bf16.
pub fn baracuda_kernels_pad_circular_bf16_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_pad_circular_bf16_can_implement` (baracuda kernels pad circular bf16 can implement).
pub fn baracuda_kernels_pad_circular_bf16_can_implement(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Pad circular, f64.
pub fn baracuda_kernels_pad_circular_f64_run(
output_numel: i64,
rank: i32,
input_shape: *const i32,
output_shape: *const i32,
pad_low: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `pad_circular_f64`.
///
/// Same problem description as `baracuda_kernels_pad_circular_f64_run`
/// minus the workspace and stream arguments; `y` is taken as a `const`
/// pointer here. Returns an `i32` status code; `0` means success.
pub fn baracuda_kernels_pad_circular_f64_can_implement(
    output_numel: i64,
    rank: i32,
    input_shape: *const i32,
    output_shape: *const i32,
    pad_low: *const i32,
    stride_x: *const i64,
    stride_y: *const i64,
    x: *const c_void,
    y: *const c_void,
) -> i32;
}
// ============================================================================
// Shape / layout — Pad constant BW (slice)
// ============================================================================
//
// Backward of `y = pad(x, pad_low, pad_high, mode=Constant, value=v)`:
// `dx = dy[pad_low : pad_low + input_shape]` — pure slice. The
// gradient at pad-region cells is identically zero (the forward wrote
// a constant there) and is discarded. Iterates `input_numel` (dx-coord
// space). One launcher per fp dtype.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
    /// Launches the f32 pad-constant backward kernel.
    ///
    /// The backward is a pure slice: `dx` receives the region of `dy`
    /// that starts at `pad_low` and spans `input_shape`. `input_numel`
    /// is the element count in `dx` coordinate space. `workspace` /
    /// `workspace_bytes` describe caller-provided scratch memory and
    /// `stream` is the CUDA stream handle. Returns an `i32` status code;
    /// `0` means success.
    pub fn baracuda_kernels_pad_constant_backward_f32_run(
        input_numel: i64,
        rank: i32,
        input_shape: *const i32,
        pad_low: *const i32,
        stride_dy: *const i64,
        stride_dx: *const i64,
        dy: *const c_void,
        dx: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Pre-launch implementability check for `pad_constant_backward_f32`.
    ///
    /// Takes the `_run` arguments minus workspace and stream. Returns an
    /// `i32` status code; `0` means success.
    pub fn baracuda_kernels_pad_constant_backward_f32_can_implement(
        input_numel: i64,
        rank: i32,
        input_shape: *const i32,
        pad_low: *const i32,
        stride_dy: *const i64,
        stride_dx: *const i64,
        dy: *const c_void,
        dx: *const c_void,
    ) -> i32;

    /// Launches the f16 pad-constant backward kernel (slice of `dy`
    /// into `dx`). Arguments and return value as for the f32 launcher.
    pub fn baracuda_kernels_pad_constant_backward_f16_run(
        input_numel: i64,
        rank: i32,
        input_shape: *const i32,
        pad_low: *const i32,
        stride_dy: *const i64,
        stride_dx: *const i64,
        dy: *const c_void,
        dx: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Pre-launch implementability check for `pad_constant_backward_f16`.
    ///
    /// Takes the `_run` arguments minus workspace and stream. Returns an
    /// `i32` status code; `0` means success.
    pub fn baracuda_kernels_pad_constant_backward_f16_can_implement(
        input_numel: i64,
        rank: i32,
        input_shape: *const i32,
        pad_low: *const i32,
        stride_dy: *const i64,
        stride_dx: *const i64,
        dy: *const c_void,
        dx: *const c_void,
    ) -> i32;

    /// Launches the bf16 pad-constant backward kernel (slice of `dy`
    /// into `dx`). Arguments and return value as for the f32 launcher.
    pub fn baracuda_kernels_pad_constant_backward_bf16_run(
        input_numel: i64,
        rank: i32,
        input_shape: *const i32,
        pad_low: *const i32,
        stride_dy: *const i64,
        stride_dx: *const i64,
        dy: *const c_void,
        dx: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Pre-launch implementability check for `pad_constant_backward_bf16`.
    ///
    /// Takes the `_run` arguments minus workspace and stream. Returns an
    /// `i32` status code; `0` means success.
    pub fn baracuda_kernels_pad_constant_backward_bf16_can_implement(
        input_numel: i64,
        rank: i32,
        input_shape: *const i32,
        pad_low: *const i32,
        stride_dy: *const i64,
        stride_dx: *const i64,
        dy: *const c_void,
        dx: *const c_void,
    ) -> i32;

    /// Launches the f64 pad-constant backward kernel (slice of `dy`
    /// into `dx`). Arguments and return value as for the f32 launcher.
    pub fn baracuda_kernels_pad_constant_backward_f64_run(
        input_numel: i64,
        rank: i32,
        input_shape: *const i32,
        pad_low: *const i32,
        stride_dy: *const i64,
        stride_dx: *const i64,
        dy: *const c_void,
        dx: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Pre-launch implementability check for `pad_constant_backward_f64`.
    ///
    /// Takes the `_run` arguments minus workspace and stream. Returns an
    /// `i32` status code; `0` means success.
    pub fn baracuda_kernels_pad_constant_backward_f64_can_implement(
        input_numel: i64,
        rank: i32,
        input_shape: *const i32,
        pad_low: *const i32,
        stride_dy: *const i64,
        stride_dx: *const i64,
        dy: *const c_void,
        dx: *const c_void,
    ) -> i32;
}
// ============================================================================
// Shape / layout — Repeat backward (Category N, gather-adjoint sum)
// ============================================================================
//
// Backward of `y = repeat(x, repeats)`: `dx[c_in] = sum_{k}
// dy[c_in + k * input_shape]` per axis — every dy cell whose
// `c_out[d] mod input_shape[d] == c_in[d]` for all d contributes. One
// thread per dx cell loops the per-axis repeats grid and accumulates;
// f16 / bf16 accumulate in float for numerical stability. Iterates
// `input_numel` (dx-coord space). One launcher per fp dtype.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
    /// Launches the f32 repeat backward kernel.
    ///
    /// Each `dx` cell receives the sum of every `dy` cell that the
    /// forward `repeat` copied it into (gather-adjoint sum).
    /// `input_numel` is the element count in `dx` coordinate space and
    /// `repeats` carries the per-axis repeat counts. `workspace` /
    /// `workspace_bytes` describe caller-provided scratch memory and
    /// `stream` is the CUDA stream handle. Returns an `i32` status code;
    /// `0` means success.
    pub fn baracuda_kernels_repeat_backward_f32_run(
        input_numel: i64,
        rank: i32,
        input_shape: *const i32,
        repeats: *const i32,
        stride_dy: *const i64,
        stride_dx: *const i64,
        dy: *const c_void,
        dx: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Pre-launch implementability check for `repeat_backward_f32`.
    ///
    /// Takes the `_run` arguments minus workspace and stream. Returns an
    /// `i32` status code; `0` means success.
    pub fn baracuda_kernels_repeat_backward_f32_can_implement(
        input_numel: i64,
        rank: i32,
        input_shape: *const i32,
        repeats: *const i32,
        stride_dy: *const i64,
        stride_dx: *const i64,
        dy: *const c_void,
        dx: *const c_void,
    ) -> i32;

    /// Launches the f16 repeat backward kernel (gather-adjoint sum).
    /// The sum is accumulated in float. Arguments and return value as
    /// for the f32 launcher.
    pub fn baracuda_kernels_repeat_backward_f16_run(
        input_numel: i64,
        rank: i32,
        input_shape: *const i32,
        repeats: *const i32,
        stride_dy: *const i64,
        stride_dx: *const i64,
        dy: *const c_void,
        dx: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Pre-launch implementability check for `repeat_backward_f16`.
    ///
    /// Takes the `_run` arguments minus workspace and stream. Returns an
    /// `i32` status code; `0` means success.
    pub fn baracuda_kernels_repeat_backward_f16_can_implement(
        input_numel: i64,
        rank: i32,
        input_shape: *const i32,
        repeats: *const i32,
        stride_dy: *const i64,
        stride_dx: *const i64,
        dy: *const c_void,
        dx: *const c_void,
    ) -> i32;

    /// Launches the bf16 repeat backward kernel (gather-adjoint sum).
    /// The sum is accumulated in float. Arguments and return value as
    /// for the f32 launcher.
    pub fn baracuda_kernels_repeat_backward_bf16_run(
        input_numel: i64,
        rank: i32,
        input_shape: *const i32,
        repeats: *const i32,
        stride_dy: *const i64,
        stride_dx: *const i64,
        dy: *const c_void,
        dx: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Pre-launch implementability check for `repeat_backward_bf16`.
    ///
    /// Takes the `_run` arguments minus workspace and stream. Returns an
    /// `i32` status code; `0` means success.
    pub fn baracuda_kernels_repeat_backward_bf16_can_implement(
        input_numel: i64,
        rank: i32,
        input_shape: *const i32,
        repeats: *const i32,
        stride_dy: *const i64,
        stride_dx: *const i64,
        dy: *const c_void,
        dx: *const c_void,
    ) -> i32;

    /// Launches the f64 repeat backward kernel (gather-adjoint sum).
    /// Arguments and return value as for the f32 launcher.
    pub fn baracuda_kernels_repeat_backward_f64_run(
        input_numel: i64,
        rank: i32,
        input_shape: *const i32,
        repeats: *const i32,
        stride_dy: *const i64,
        stride_dx: *const i64,
        dy: *const c_void,
        dx: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Pre-launch implementability check for `repeat_backward_f64`.
    ///
    /// Takes the `_run` arguments minus workspace and stream. Returns an
    /// `i32` status code; `0` means success.
    pub fn baracuda_kernels_repeat_backward_f64_can_implement(
        input_numel: i64,
        rank: i32,
        input_shape: *const i32,
        repeats: *const i32,
        stride_dy: *const i64,
        stride_dx: *const i64,
        dy: *const c_void,
        dx: *const c_void,
    ) -> i32;
}
// ============================================================================
// Shape / layout — Concat2 backward (Category N, pure slice-split)
// ============================================================================
//
// Backward of `y = cat(a, b, dim=k)`: every dy cell maps to exactly one
// of `da` or `db`. Pure inverse routing — bit-exact across every wired
// dtype, no arithmetic. `da` collects `dy[..., :split_offset, ...]` and
// `db` collects `dy[..., split_offset:, ...]` along `concat_dim`. One
// thread per dy cell. Iterates `output_numel` (= dy.numel()).
// `split_offset` is `a.shape[concat_dim]` from the forward.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
    /// Launches the f32 concat2 backward kernel.
    ///
    /// Routes every `dy` cell to exactly one of `da` / `db`: indices
    /// below `split_offset` along `concat_dim` go to `da`, the rest to
    /// `db`. No arithmetic is performed, so the result is bit-exact.
    /// `output_numel` is the element count of `dy`; `split_offset` is
    /// the forward's `a.shape[concat_dim]`. `workspace` /
    /// `workspace_bytes` describe caller-provided scratch memory and
    /// `stream` is the CUDA stream handle. Returns an `i32` status code;
    /// `0` means success.
    pub fn baracuda_kernels_concat2_backward_f32_run(
        output_numel: i64,
        rank: i32,
        output_shape: *const i32,
        concat_dim: i32,
        split_offset: i32,
        stride_dy: *const i64,
        stride_da: *const i64,
        stride_db: *const i64,
        dy: *const c_void,
        da: *mut c_void,
        db: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Pre-launch implementability check for `concat2_backward_f32`.
    ///
    /// Takes the `_run` arguments minus workspace and stream. Returns an
    /// `i32` status code; `0` means success.
    pub fn baracuda_kernels_concat2_backward_f32_can_implement(
        output_numel: i64,
        rank: i32,
        output_shape: *const i32,
        concat_dim: i32,
        split_offset: i32,
        stride_dy: *const i64,
        stride_da: *const i64,
        stride_db: *const i64,
        dy: *const c_void,
        da: *const c_void,
        db: *const c_void,
    ) -> i32;

    /// Launches the f16 concat2 backward kernel (slice-split of `dy`
    /// into `da` / `db`). Arguments and return value as for the f32
    /// launcher.
    pub fn baracuda_kernels_concat2_backward_f16_run(
        output_numel: i64,
        rank: i32,
        output_shape: *const i32,
        concat_dim: i32,
        split_offset: i32,
        stride_dy: *const i64,
        stride_da: *const i64,
        stride_db: *const i64,
        dy: *const c_void,
        da: *mut c_void,
        db: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Pre-launch implementability check for `concat2_backward_f16`.
    ///
    /// Takes the `_run` arguments minus workspace and stream. Returns an
    /// `i32` status code; `0` means success.
    pub fn baracuda_kernels_concat2_backward_f16_can_implement(
        output_numel: i64,
        rank: i32,
        output_shape: *const i32,
        concat_dim: i32,
        split_offset: i32,
        stride_dy: *const i64,
        stride_da: *const i64,
        stride_db: *const i64,
        dy: *const c_void,
        da: *const c_void,
        db: *const c_void,
    ) -> i32;

    /// Launches the bf16 concat2 backward kernel (slice-split of `dy`
    /// into `da` / `db`). Arguments and return value as for the f32
    /// launcher.
    pub fn baracuda_kernels_concat2_backward_bf16_run(
        output_numel: i64,
        rank: i32,
        output_shape: *const i32,
        concat_dim: i32,
        split_offset: i32,
        stride_dy: *const i64,
        stride_da: *const i64,
        stride_db: *const i64,
        dy: *const c_void,
        da: *mut c_void,
        db: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Pre-launch implementability check for `concat2_backward_bf16`.
    ///
    /// Takes the `_run` arguments minus workspace and stream. Returns an
    /// `i32` status code; `0` means success.
    pub fn baracuda_kernels_concat2_backward_bf16_can_implement(
        output_numel: i64,
        rank: i32,
        output_shape: *const i32,
        concat_dim: i32,
        split_offset: i32,
        stride_dy: *const i64,
        stride_da: *const i64,
        stride_db: *const i64,
        dy: *const c_void,
        da: *const c_void,
        db: *const c_void,
    ) -> i32;

    /// Launches the f64 concat2 backward kernel (slice-split of `dy`
    /// into `da` / `db`). Arguments and return value as for the f32
    /// launcher.
    pub fn baracuda_kernels_concat2_backward_f64_run(
        output_numel: i64,
        rank: i32,
        output_shape: *const i32,
        concat_dim: i32,
        split_offset: i32,
        stride_dy: *const i64,
        stride_da: *const i64,
        stride_db: *const i64,
        dy: *const c_void,
        da: *mut c_void,
        db: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Pre-launch implementability check for `concat2_backward_f64`.
    ///
    /// Takes the `_run` arguments minus workspace and stream. Returns an
    /// `i32` status code; `0` means success.
    pub fn baracuda_kernels_concat2_backward_f64_can_implement(
        output_numel: i64,
        rank: i32,
        output_shape: *const i32,
        concat_dim: i32,
        split_offset: i32,
        stride_dy: *const i64,
        stride_da: *const i64,
        stride_db: *const i64,
        dy: *const c_void,
        da: *const c_void,
        db: *const c_void,
    ) -> i32;
}
// ============================================================================
// Elementwise — scaled ternary (3→1 ops with a scalar parameter)
// ============================================================================
//
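// `Addcmul` and `Addcdiv` follow PyTorch's
// `torch.addcmul(input, tensor1, tensor2, value=k)` /
// `torch.addcdiv(input, tensor1, tensor2, value=k)` semantics, with
// `a` = input, `b` = tensor1, `c` = tensor2, and `scale` = value:
// addcmul: y = a + scale * b * c
// addcdiv: y = a + scale * (b / c)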
//
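// FFI signature mirrors the unparameterized ternary launchers but
// inserts a `float scale` parameter between the y pointer and the
// workspace pointer on every `_run` entry point (`f32` on the Rust
// side for all value dtypes, including f64). The
// `_strided_can_implement` checks take `scale` as their trailing
// argument; the contig `_can_implement` checks do not take it. The
// Rust dispatcher reads `desc.scale` and forwards it; unparameterized
// ternary ops (Clamp, Fma) take a separate FFI without the scale arg
// (above).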
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
    // ---- addcmul ----

    /// Contiguous fast path for `y = a + scale * b * c`, f32 values.
    ///
    /// Returns an `i32` status code; `0` means success.
    pub fn baracuda_kernels_ternary_addcmul_f32_run(
        numel: i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *mut c_void,
        scale: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching `addcmul_f32`.
    pub fn baracuda_kernels_ternary_addcmul_f32_can_implement(
        numel: i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *const c_void,
    ) -> i32;

    /// Strided / broadcast path for `y = a + scale * b * c`, f32 values.
    /// Each operand carries its own stride array.
    pub fn baracuda_kernels_ternary_addcmul_f32_strided_run(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_c: *const i64,
        stride_y: *const i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *mut c_void,
        scale: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching
    /// `addcmul_f32_strided`.
    pub fn baracuda_kernels_ternary_addcmul_f32_strided_can_implement(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_c: *const i64,
        stride_y: *const i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *const c_void,
        scale: f32,
    ) -> i32;

    /// Contiguous `addcmul`, f16 values.
    pub fn baracuda_kernels_ternary_addcmul_f16_run(
        numel: i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *mut c_void,
        scale: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching `addcmul_f16`.
    pub fn baracuda_kernels_ternary_addcmul_f16_can_implement(
        numel: i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *const c_void,
    ) -> i32;

    /// Strided `addcmul`, f16 values.
    pub fn baracuda_kernels_ternary_addcmul_f16_strided_run(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_c: *const i64,
        stride_y: *const i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *mut c_void,
        scale: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching
    /// `addcmul_f16_strided`.
    pub fn baracuda_kernels_ternary_addcmul_f16_strided_can_implement(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_c: *const i64,
        stride_y: *const i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *const c_void,
        scale: f32,
    ) -> i32;

    /// Contiguous `addcmul`, bf16 values.
    pub fn baracuda_kernels_ternary_addcmul_bf16_run(
        numel: i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *mut c_void,
        scale: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching `addcmul_bf16`.
    pub fn baracuda_kernels_ternary_addcmul_bf16_can_implement(
        numel: i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *const c_void,
    ) -> i32;

    /// Strided `addcmul`, bf16 values.
    pub fn baracuda_kernels_ternary_addcmul_bf16_strided_run(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_c: *const i64,
        stride_y: *const i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *mut c_void,
        scale: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching
    /// `addcmul_bf16_strided`.
    pub fn baracuda_kernels_ternary_addcmul_bf16_strided_can_implement(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_c: *const i64,
        stride_y: *const i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *const c_void,
        scale: f32,
    ) -> i32;

    /// Contiguous `addcmul`, f64 values. `scale` is still passed as `f32`.
    pub fn baracuda_kernels_ternary_addcmul_f64_run(
        numel: i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *mut c_void,
        scale: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching `addcmul_f64`.
    pub fn baracuda_kernels_ternary_addcmul_f64_can_implement(
        numel: i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *const c_void,
    ) -> i32;

    /// Strided `addcmul`, f64 values. `scale` is still passed as `f32`.
    pub fn baracuda_kernels_ternary_addcmul_f64_strided_run(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_c: *const i64,
        stride_y: *const i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *mut c_void,
        scale: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching
    /// `addcmul_f64_strided`.
    pub fn baracuda_kernels_ternary_addcmul_f64_strided_can_implement(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_c: *const i64,
        stride_y: *const i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *const c_void,
        scale: f32,
    ) -> i32;

    // ---- addcdiv ----

    /// Contiguous fast path for `y = a + scale * (b / c)`, f32 values.
    ///
    /// Returns an `i32` status code; `0` means success.
    pub fn baracuda_kernels_ternary_addcdiv_f32_run(
        numel: i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *mut c_void,
        scale: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching `addcdiv_f32`.
    pub fn baracuda_kernels_ternary_addcdiv_f32_can_implement(
        numel: i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *const c_void,
    ) -> i32;

    /// Strided `addcdiv`, f32 values. Each operand carries its own
    /// stride array.
    pub fn baracuda_kernels_ternary_addcdiv_f32_strided_run(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_c: *const i64,
        stride_y: *const i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *mut c_void,
        scale: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching
    /// `addcdiv_f32_strided`.
    pub fn baracuda_kernels_ternary_addcdiv_f32_strided_can_implement(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_c: *const i64,
        stride_y: *const i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *const c_void,
        scale: f32,
    ) -> i32;

    /// Contiguous `addcdiv`, f16 values.
    pub fn baracuda_kernels_ternary_addcdiv_f16_run(
        numel: i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *mut c_void,
        scale: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching `addcdiv_f16`.
    pub fn baracuda_kernels_ternary_addcdiv_f16_can_implement(
        numel: i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *const c_void,
    ) -> i32;

    /// Strided `addcdiv`, f16 values.
    pub fn baracuda_kernels_ternary_addcdiv_f16_strided_run(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_c: *const i64,
        stride_y: *const i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *mut c_void,
        scale: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching
    /// `addcdiv_f16_strided`.
    pub fn baracuda_kernels_ternary_addcdiv_f16_strided_can_implement(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_c: *const i64,
        stride_y: *const i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *const c_void,
        scale: f32,
    ) -> i32;

    /// Contiguous `addcdiv`, bf16 values.
    pub fn baracuda_kernels_ternary_addcdiv_bf16_run(
        numel: i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *mut c_void,
        scale: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching `addcdiv_bf16`.
    pub fn baracuda_kernels_ternary_addcdiv_bf16_can_implement(
        numel: i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *const c_void,
    ) -> i32;

    /// Strided `addcdiv`, bf16 values.
    pub fn baracuda_kernels_ternary_addcdiv_bf16_strided_run(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_c: *const i64,
        stride_y: *const i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *mut c_void,
        scale: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching
    /// `addcdiv_bf16_strided`.
    pub fn baracuda_kernels_ternary_addcdiv_bf16_strided_can_implement(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_c: *const i64,
        stride_y: *const i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *const c_void,
        scale: f32,
    ) -> i32;

    /// Contiguous `addcdiv`, f64 values. `scale` is still passed as `f32`.
    pub fn baracuda_kernels_ternary_addcdiv_f64_run(
        numel: i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *mut c_void,
        scale: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching `addcdiv_f64`.
    pub fn baracuda_kernels_ternary_addcdiv_f64_can_implement(
        numel: i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *const c_void,
    ) -> i32;

    /// Strided `addcdiv`, f64 values. `scale` is still passed as `f32`.
    pub fn baracuda_kernels_ternary_addcdiv_f64_strided_run(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_c: *const i64,
        stride_y: *const i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *mut c_void,
        scale: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching
    /// `addcdiv_f64_strided`.
    pub fn baracuda_kernels_ternary_addcdiv_f64_strided_can_implement(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_c: *const i64,
        stride_y: *const i64,
        a: *const c_void,
        b: *const c_void,
        c: *const c_void,
        y: *const c_void,
        scale: f32,
    ) -> i32;
}
// ============================================================================
// Elementwise — where (heterogeneous-dtype ternary, u8 cond + T → T)
// ============================================================================
//
// `y = cond ? a : b` with `cond: TensorRef<u8, N>` (PyTorch / NumPy
// bool storage convention: 0 = false, non-zero = true) and same-dtype
// a / b / y. Distinct family from the homogeneous-dtype ternary path
// above — the cond input has a different dtype than the value inputs,
// so the FFI takes an extra `stride_cond` array on the strided path.
//
// All 4 FP value dtypes wired: {f32, f16, bf16, f64} × {contig,
// strided}.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
    /// Contiguous fast path for `where(cond, a, b)` with f32 values and
    /// a u8 condition.
    ///
    /// Computes `y = cond ? a : b` per element, reading `cond` as a
    /// PyTorch-style bool: zero picks `b`, any non-zero value picks `a`.
    ///
    /// This is the where-ternary trailblazer: the safety and aliasing
    /// contract below applies to every other where-family launcher,
    /// for all value dtypes and cond-dtype variants (`where_u32cond_*`,
    /// `where_i64cond_*`).
    ///
    /// # Safety
    /// Every device pointer must stay valid until the launch has
    /// finished. `cond` must reference at least `numel` `u8`s, and `a`,
    /// `b`, `y` at least `numel` `f32`s each.
    ///
    /// **Aliasing (Phase 64)**: `y` may alias `a` or `b` (or both when
    /// `a == b`). The contig kernel evaluates
    /// `y[i] = cond[i] ? a[i] : b[i]` with each thread confined to its
    /// own index `i`, reading `cond[i]`, `a[i]` and `b[i]` before it
    /// writes `y[i]`. Callers implementing `Op::WhereInplace` or
    /// conditional gradient masking can therefore dispatch this forward
    /// symbol with `a_ptr == y_ptr` (or `b_ptr == y_ptr`); no dedicated
    /// `_inplace_` variant is needed. Aliasing `y` with `cond` is only
    /// possible when both have the same byte width: the u8-cond variant
    /// won't alias against an f32 `y`, whereas the u32-cond / i64-cond
    /// variants line up with f32 / f64 `y` respectively. This contract
    /// is stable across baracuda versions.
    pub fn baracuda_kernels_where_f32_run(
        numel: i64,
        cond: *const c_void,
        a: *const c_void,
        b: *const c_void,
        y: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching `where_f32`.
    pub fn baracuda_kernels_where_f32_can_implement(
        numel: i64,
        cond: *const c_void,
        a: *const c_void,
        b: *const c_void,
        y: *const c_void,
    ) -> i32;

    /// Strided / broadcast path for `where(cond, a, b)` with f32 values.
    ///
    /// Every operand has its own stride array, so `cond` can be
    /// broadcast independently of `a` and `b` — for example a per-row
    /// mask laid out as `[M, 1] + [M, N] + [M, N]`.
    pub fn baracuda_kernels_where_f32_strided_run(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_cond: *const i64,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_y: *const i64,
        cond: *const c_void,
        a: *const c_void,
        b: *const c_void,
        y: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching
    /// `where_f32_strided`.
    pub fn baracuda_kernels_where_f32_strided_can_implement(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_cond: *const i64,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_y: *const i64,
        cond: *const c_void,
        a: *const c_void,
        b: *const c_void,
        y: *const c_void,
    ) -> i32;
}
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
    /// Contiguous fast path for `where(cond, a, b)` with f16 values and
    /// a u8 condition.
    ///
    /// Computes `y = cond ? a : b` per element, reading `cond` as a
    /// PyTorch-style bool: zero picks `b`, any non-zero value picks `a`.
    ///
    /// # Safety
    /// Every device pointer must stay valid until the launch has
    /// finished. `cond` must reference at least `numel` `u8`s, and `a`,
    /// `b`, `y` at least `numel` `f16`s each.
    pub fn baracuda_kernels_where_f16_run(
        numel: i64,
        cond: *const c_void,
        a: *const c_void,
        b: *const c_void,
        y: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching `where_f16`.
    pub fn baracuda_kernels_where_f16_can_implement(
        numel: i64,
        cond: *const c_void,
        a: *const c_void,
        b: *const c_void,
        y: *const c_void,
    ) -> i32;

    /// Strided / broadcast path for `where(cond, a, b)` with f16 values.
    ///
    /// Every operand has its own stride array, so `cond` can be
    /// broadcast independently of `a` and `b` — for example a per-row
    /// mask laid out as `[M, 1] + [M, N] + [M, N]`.
    pub fn baracuda_kernels_where_f16_strided_run(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_cond: *const i64,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_y: *const i64,
        cond: *const c_void,
        a: *const c_void,
        b: *const c_void,
        y: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching
    /// `where_f16_strided`.
    pub fn baracuda_kernels_where_f16_strided_can_implement(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_cond: *const i64,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_y: *const i64,
        cond: *const c_void,
        a: *const c_void,
        b: *const c_void,
        y: *const c_void,
    ) -> i32;
}
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
    /// Contiguous fast path for `where(cond, a, b)` with bf16 values and
    /// a u8 condition.
    ///
    /// Computes `y = cond ? a : b` per element, reading `cond` as a
    /// PyTorch-style bool: zero picks `b`, any non-zero value picks `a`.
    ///
    /// # Safety
    /// Every device pointer must stay valid until the launch has
    /// finished. `cond` must reference at least `numel` `u8`s, and `a`,
    /// `b`, `y` at least `numel` `bf16`s each.
    pub fn baracuda_kernels_where_bf16_run(
        numel: i64,
        cond: *const c_void,
        a: *const c_void,
        b: *const c_void,
        y: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching `where_bf16`.
    pub fn baracuda_kernels_where_bf16_can_implement(
        numel: i64,
        cond: *const c_void,
        a: *const c_void,
        b: *const c_void,
        y: *const c_void,
    ) -> i32;

    /// Strided / broadcast path for `where(cond, a, b)` with bf16 values.
    ///
    /// Every operand has its own stride array, so `cond` can be
    /// broadcast independently of `a` and `b` — for example a per-row
    /// mask laid out as `[M, 1] + [M, N] + [M, N]`.
    pub fn baracuda_kernels_where_bf16_strided_run(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_cond: *const i64,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_y: *const i64,
        cond: *const c_void,
        a: *const c_void,
        b: *const c_void,
        y: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching
    /// `where_bf16_strided`.
    pub fn baracuda_kernels_where_bf16_strided_can_implement(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_cond: *const i64,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_y: *const i64,
        cond: *const c_void,
        a: *const c_void,
        b: *const c_void,
        y: *const c_void,
    ) -> i32;
}
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
    /// Contiguous fast path for `where(cond, a, b)` with f64 values and
    /// a u8 condition.
    ///
    /// Computes `y = cond ? a : b` per element, reading `cond` as a
    /// PyTorch-style bool: zero picks `b`, any non-zero value picks `a`.
    ///
    /// # Safety
    /// Every device pointer must stay valid until the launch has
    /// finished. `cond` must reference at least `numel` `u8`s, and `a`,
    /// `b`, `y` at least `numel` `f64`s each.
    pub fn baracuda_kernels_where_f64_run(
        numel: i64,
        cond: *const c_void,
        a: *const c_void,
        b: *const c_void,
        y: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching `where_f64`.
    pub fn baracuda_kernels_where_f64_can_implement(
        numel: i64,
        cond: *const c_void,
        a: *const c_void,
        b: *const c_void,
        y: *const c_void,
    ) -> i32;

    /// Strided / broadcast path for `where(cond, a, b)` with f64 values.
    ///
    /// Every operand has its own stride array, so `cond` can be
    /// broadcast independently of `a` and `b` — for example a per-row
    /// mask laid out as `[M, 1] + [M, N] + [M, N]`.
    pub fn baracuda_kernels_where_f64_strided_run(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_cond: *const i64,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_y: *const i64,
        cond: *const c_void,
        a: *const c_void,
        b: *const c_void,
        y: *mut c_void,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Implementability check to call before launching
    /// `where_f64_strided`.
    pub fn baracuda_kernels_where_f64_strided_can_implement(
        numel: i64,
        rank: i32,
        shape: *const i32,
        stride_cond: *const i64,
        stride_a: *const i64,
        stride_b: *const i64,
        stride_y: *const i64,
        cond: *const c_void,
        a: *const c_void,
        b: *const c_void,
        y: *const c_void,
    ) -> i32;
}
// ============================================================================
// Elementwise — where dtype matrix fanout (Phase 38 / Fuel 6c.4 Gap 3)
// ============================================================================
//
// Extends the heterogeneous-dtype `where(cond, a, b)` ternary's cond
// dimension to cover `u32` and `i64` (in addition to the original
// `u8`) and the value dimension to cover the integer family
// (`u8` / `i8` / `u32` / `i16` / `i32` / `i64`) and Fp8E4M3 (in
// addition to the original `{f32, f16, bf16, f64}`).
//
// Naming convention: explicit `<cond>cond_<value>` prefix on every
// symbol in this section, e.g. `where_u32cond_f32_run`. The original
// `where_<value>_run` family above stays in place — those symbols
// implicitly mean "u8 cond" and are preserved verbatim for source
// compat. Each symbol pairs `<sym>_run` (kernel launch) with
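// `<sym>_can_implement` (pre-launch validation). Strided variants
// use `<sym>_strided_run`. NOTE(review): the declarations below (e.g.
// `where_u32cond_f32_strided_can_implement`) and the original where
// family above do carry `_strided_can_implement` companions, so the
// declaration counts below may undercount — confirm against the C
// launchers.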
//
// Cond semantics across every variant: any non-zero value selects `a`,
// zero selects `b`. The kernel uses `cond != Cond(0)`, which compiles
// to the natural `setp.ne` PTX instruction for any integer width.
//
// Counts: 16 (U32/I64-cond × 4 fp values) + 36 (3 conds × 6 int values)
// + 6 (3 conds × Fp8E4M3) = 58 new symbol pairs. Contig + strided
// makes 116 total `extern "C"` declarations; contig adds another 58
// `_can_implement` companions = 174 declarations in this section.
// ----------------------------------------------------------------------------
// (a) U32-cond × {f32, f16, bf16, f64} — contig + strided
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Contiguous fast path for `where(cond, a, b)` with a u32 condition
/// and f32 values.
///
/// Computes `y = (cond != 0) ? a : b` per element. `cond` follows the
/// PyTorch bool convention (zero picks `b`, non-zero picks `a`),
/// independent of the integer width that stores it.
///
/// # Safety
/// Every device pointer must stay valid until the launch has finished.
/// `cond` must reference at least `numel` `u32`s, and `a`, `b`, `y` at
/// least `numel` `f32`s each.
pub fn baracuda_kernels_where_u32cond_f32_run(
    numel: i64,
    cond: *const c_void,
    a: *const c_void,
    b: *const c_void,
    y: *mut c_void,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u32cond_f32`.
pub fn baracuda_kernels_where_u32cond_f32_can_implement(
numel: i64,
cond: *const c_void,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + f32 values, strided / broadcast
/// path. Each operand carries its own stride array.
pub fn baracuda_kernels_where_u32cond_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_cond: *const i64,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check companion.
pub fn baracuda_kernels_where_u32cond_f32_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + f64 values, contig fast path.
pub fn baracuda_kernels_where_u32cond_f64_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u32cond_f64`.
pub fn baracuda_kernels_where_u32cond_f64_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + f64 values, strided / broadcast.
pub fn baracuda_kernels_where_u32cond_f64_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check companion.
pub fn baracuda_kernels_where_u32cond_f64_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + f16 values, contig fast path.
pub fn baracuda_kernels_where_u32cond_f16_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u32cond_f16`.
pub fn baracuda_kernels_where_u32cond_f16_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + f16 values, strided / broadcast.
pub fn baracuda_kernels_where_u32cond_f16_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check companion.
pub fn baracuda_kernels_where_u32cond_f16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + bf16 values, contig fast path.
pub fn baracuda_kernels_where_u32cond_bf16_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u32cond_bf16`.
pub fn baracuda_kernels_where_u32cond_bf16_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + bf16 values, strided / broadcast.
pub fn baracuda_kernels_where_u32cond_bf16_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check companion.
pub fn baracuda_kernels_where_u32cond_bf16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// (a-cont) I64-cond × {f32, f16, bf16, f64} — contig + strided
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// `where(cond, a, b)`, i64 cond + f32 values, contig fast path.
///
/// Any non-zero `cond` element selects `a`; zero selects `b`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Per the crate-level contract, pointer arguments are not
/// bounds-checked and must be valid device addresses, and `stream`
/// must be a valid CUDA stream. Presumably `cond` must span `numel`
/// `i64`s and `a` / `b` / `y` `numel` `f32`s, by analogy with
/// `where_u32cond_f32_run` — confirm against the kernel source.
pub fn baracuda_kernels_where_i64cond_f32_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_i64cond_f32` (contig path). Returns a
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_where_i64cond_f32_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, i64 cond + f32 values, strided / broadcast.
/// Each operand carries its own stride array.
///
/// NOTE(review): `shape` and the four stride arrays presumably hold
/// `rank` entries each — confirm against the kernel source.
pub fn baracuda_kernels_where_i64cond_f32_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_i64cond_f32_strided`.
pub fn baracuda_kernels_where_i64cond_f32_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, i64 cond + f64 values, contig fast path.
pub fn baracuda_kernels_where_i64cond_f64_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_i64cond_f64` (contig path).
pub fn baracuda_kernels_where_i64cond_f64_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, i64 cond + f64 values, strided / broadcast.
pub fn baracuda_kernels_where_i64cond_f64_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_i64cond_f64_strided`.
pub fn baracuda_kernels_where_i64cond_f64_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, i64 cond + f16 values, contig fast path.
pub fn baracuda_kernels_where_i64cond_f16_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_i64cond_f16` (contig path).
pub fn baracuda_kernels_where_i64cond_f16_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, i64 cond + f16 values, strided / broadcast.
pub fn baracuda_kernels_where_i64cond_f16_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_i64cond_f16_strided`.
pub fn baracuda_kernels_where_i64cond_f16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, i64 cond + bf16 values, contig fast path.
pub fn baracuda_kernels_where_i64cond_bf16_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_i64cond_bf16` (contig path).
pub fn baracuda_kernels_where_i64cond_bf16_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, i64 cond + bf16 values, strided / broadcast.
pub fn baracuda_kernels_where_i64cond_bf16_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_i64cond_bf16_strided`.
pub fn baracuda_kernels_where_i64cond_bf16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// (b) U8-cond × {u8, i8, u32, i16, i32, i64} — contig + strided
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// `where(cond, a, b)`, u8 cond + u8 values, contig fast path.
///
/// Any non-zero `cond` element selects `a`; zero selects `b`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Per the crate-level contract, pointer arguments are not
/// bounds-checked and must be valid device addresses, and `stream`
/// must be a valid CUDA stream. Presumably `cond`, `a`, `b`, and `y`
/// must each span at least `numel` `u8`s — confirm against the
/// kernel source.
pub fn baracuda_kernels_where_u8cond_u8_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u8cond_u8` (contig path). Returns a
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_where_u8cond_u8_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u8 cond + u8 values, strided / broadcast.
/// Each operand carries its own stride array.
///
/// NOTE(review): `shape` and the four stride arrays presumably hold
/// `rank` entries each — confirm against the kernel source.
pub fn baracuda_kernels_where_u8cond_u8_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u8cond_u8_strided`.
pub fn baracuda_kernels_where_u8cond_u8_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u8 cond + i8 values, contig fast path.
pub fn baracuda_kernels_where_u8cond_i8_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u8cond_i8` (contig path).
pub fn baracuda_kernels_where_u8cond_i8_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u8 cond + i8 values, strided / broadcast.
pub fn baracuda_kernels_where_u8cond_i8_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u8cond_i8_strided`.
pub fn baracuda_kernels_where_u8cond_i8_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u8 cond + u32 values, contig fast path.
pub fn baracuda_kernels_where_u8cond_u32_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u8cond_u32` (contig path).
pub fn baracuda_kernels_where_u8cond_u32_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u8 cond + u32 values, strided / broadcast.
pub fn baracuda_kernels_where_u8cond_u32_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u8cond_u32_strided`.
pub fn baracuda_kernels_where_u8cond_u32_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u8 cond + i16 values, contig fast path.
pub fn baracuda_kernels_where_u8cond_i16_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u8cond_i16` (contig path).
pub fn baracuda_kernels_where_u8cond_i16_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u8 cond + i16 values, strided / broadcast.
pub fn baracuda_kernels_where_u8cond_i16_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u8cond_i16_strided`.
pub fn baracuda_kernels_where_u8cond_i16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u8 cond + i32 values, contig fast path.
pub fn baracuda_kernels_where_u8cond_i32_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u8cond_i32` (contig path).
pub fn baracuda_kernels_where_u8cond_i32_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u8 cond + i32 values, strided / broadcast.
pub fn baracuda_kernels_where_u8cond_i32_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u8cond_i32_strided`.
pub fn baracuda_kernels_where_u8cond_i32_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u8 cond + i64 values, contig fast path.
pub fn baracuda_kernels_where_u8cond_i64_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u8cond_i64` (contig path).
pub fn baracuda_kernels_where_u8cond_i64_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u8 cond + i64 values, strided / broadcast.
pub fn baracuda_kernels_where_u8cond_i64_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u8cond_i64_strided`.
pub fn baracuda_kernels_where_u8cond_i64_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// (b-cont) U32-cond × {u8, i8, u32, i16, i32, i64} — contig + strided
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// `where(cond, a, b)`, u32 cond + u8 values, contig fast path.
///
/// Any non-zero `cond` element selects `a`; zero selects `b`.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Per the crate-level contract, pointer arguments are not
/// bounds-checked and must be valid device addresses, and `stream`
/// must be a valid CUDA stream. Presumably `cond` must span `numel`
/// `u32`s and `a` / `b` / `y` `numel` `u8`s — confirm against the
/// kernel source.
pub fn baracuda_kernels_where_u32cond_u8_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u32cond_u8` (contig path). Returns a
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_where_u32cond_u8_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + u8 values, strided / broadcast.
/// Each operand carries its own stride array.
///
/// NOTE(review): `shape` and the four stride arrays presumably hold
/// `rank` entries each — confirm against the kernel source.
pub fn baracuda_kernels_where_u32cond_u8_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u32cond_u8_strided`.
pub fn baracuda_kernels_where_u32cond_u8_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + i8 values, contig fast path.
pub fn baracuda_kernels_where_u32cond_i8_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u32cond_i8` (contig path).
pub fn baracuda_kernels_where_u32cond_i8_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + i8 values, strided / broadcast.
pub fn baracuda_kernels_where_u32cond_i8_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u32cond_i8_strided`.
pub fn baracuda_kernels_where_u32cond_i8_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + u32 values, contig fast path.
pub fn baracuda_kernels_where_u32cond_u32_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u32cond_u32` (contig path).
pub fn baracuda_kernels_where_u32cond_u32_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + u32 values, strided / broadcast.
pub fn baracuda_kernels_where_u32cond_u32_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u32cond_u32_strided`.
pub fn baracuda_kernels_where_u32cond_u32_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + i16 values, contig fast path.
pub fn baracuda_kernels_where_u32cond_i16_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u32cond_i16` (contig path).
pub fn baracuda_kernels_where_u32cond_i16_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + i16 values, strided / broadcast.
pub fn baracuda_kernels_where_u32cond_i16_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u32cond_i16_strided`.
pub fn baracuda_kernels_where_u32cond_i16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + i32 values, contig fast path.
pub fn baracuda_kernels_where_u32cond_i32_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u32cond_i32` (contig path).
pub fn baracuda_kernels_where_u32cond_i32_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + i32 values, strided / broadcast.
pub fn baracuda_kernels_where_u32cond_i32_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u32cond_i32_strided`.
pub fn baracuda_kernels_where_u32cond_i32_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + i64 values, contig fast path.
pub fn baracuda_kernels_where_u32cond_i64_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u32cond_i64` (contig path).
pub fn baracuda_kernels_where_u32cond_i64_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + i64 values, strided / broadcast.
pub fn baracuda_kernels_where_u32cond_i64_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `where_u32cond_i64_strided`.
pub fn baracuda_kernels_where_u32cond_i64_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// (b-cont) I64-cond × {u8, i8, u32, i16, i32, i64} — contig + strided
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// `where(cond, a, b)`, i64 cond + u8 values, contig fast path.
pub fn baracuda_kernels_where_i64cond_u8_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_where_i64cond_u8_can_implement` (baracuda kernels where i64cond u8 can implement).
pub fn baracuda_kernels_where_i64cond_u8_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `baracuda_kernels_where_i64cond_u8_strided_run` (baracuda kernels where i64cond u8 strided run).
pub fn baracuda_kernels_where_i64cond_u8_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check companion.
pub fn baracuda_kernels_where_i64cond_u8_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, i64 cond + i8 values, contig fast path.
pub fn baracuda_kernels_where_i64cond_i8_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_where_i64cond_i8_can_implement` (baracuda kernels where i64cond i8 can implement).
pub fn baracuda_kernels_where_i64cond_i8_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `baracuda_kernels_where_i64cond_i8_strided_run` (baracuda kernels where i64cond i8 strided run).
pub fn baracuda_kernels_where_i64cond_i8_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check companion.
pub fn baracuda_kernels_where_i64cond_i8_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, i64 cond + u32 values, contig fast path.
pub fn baracuda_kernels_where_i64cond_u32_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_where_i64cond_u32_can_implement` (baracuda kernels where i64cond u32 can implement).
pub fn baracuda_kernels_where_i64cond_u32_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `baracuda_kernels_where_i64cond_u32_strided_run` (baracuda kernels where i64cond u32 strided run).
pub fn baracuda_kernels_where_i64cond_u32_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check companion.
pub fn baracuda_kernels_where_i64cond_u32_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, i64 cond + i16 values, contig fast path.
pub fn baracuda_kernels_where_i64cond_i16_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_where_i64cond_i16_can_implement` (baracuda kernels where i64cond i16 can implement).
pub fn baracuda_kernels_where_i64cond_i16_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `baracuda_kernels_where_i64cond_i16_strided_run` (baracuda kernels where i64cond i16 strided run).
pub fn baracuda_kernels_where_i64cond_i16_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check companion.
pub fn baracuda_kernels_where_i64cond_i16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, i64 cond + i32 values, contig fast path.
pub fn baracuda_kernels_where_i64cond_i32_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_where_i64cond_i32_can_implement` (baracuda kernels where i64cond i32 can implement).
pub fn baracuda_kernels_where_i64cond_i32_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `baracuda_kernels_where_i64cond_i32_strided_run` (baracuda kernels where i64cond i32 strided run).
pub fn baracuda_kernels_where_i64cond_i32_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check companion for
/// [`baracuda_kernels_where_i64cond_i32_strided_run`]: same `numel`, `rank`,
/// `shape`, stride, and operand arguments (with `y` as `*const`), without the
/// workspace / stream arguments. Returns the crate-level `i32` status code.
pub fn baracuda_kernels_where_i64cond_i32_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, i64 cond + i64 values, contig fast path.
///
/// `cond`, `a`, and `b` are read-only inputs; `y` is the output pointer.
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_where_i64cond_i64_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check companion for
/// [`baracuda_kernels_where_i64cond_i64_run`]: same `numel` and operand
/// pointers (with `y` as `*const`), without the workspace / stream
/// arguments. Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_where_i64cond_i64_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, i64 cond + i64 values, strided variant.
///
/// Besides the operand pointers it takes `rank`, a `shape` array, and one
/// stride array per operand (`stride_cond`, `stride_a`, `stride_b`,
/// `stride_y`). Returns the crate-level `i32` status code (`0` = success).
///
/// NOTE(review): the required length of `shape` / `stride_*` (presumably
/// `rank` entries each) and the stride units are not stated here — confirm
/// against the kernel source.
pub fn baracuda_kernels_where_i64cond_i64_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check companion for
/// [`baracuda_kernels_where_i64cond_i64_strided_run`]: same `numel`, `rank`,
/// `shape`, stride, and operand arguments (with `y` as `*const`), without the
/// workspace / stream arguments. Returns the crate-level `i32` status code.
pub fn baracuda_kernels_where_i64cond_i64_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// (c) {U8, U32, I64}-cond × Fp8E4M3 — contig + strided
// ----------------------------------------------------------------------------
//
// Fp8E4M3 is 1-byte storage; the kernel template instantiated on
// `uint8_t` produces bit-exact output (pure element selection — no FP
// semantics involved). Symbol name carries the `fp8e4m3` value-dtype
// tag so the FFI surface explicitly distinguishes this from the
// generic `u8`-value variants above.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// `where(cond, a, b)`, u8 cond + Fp8E4M3 values, contig fast path.
///
/// `cond`, `a`, and `b` are read-only inputs; `y` is the output pointer.
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_where_u8cond_fp8e4m3_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check companion for
/// [`baracuda_kernels_where_u8cond_fp8e4m3_run`]: same `numel` and operand
/// pointers (with `y` as `*const`), without the workspace / stream
/// arguments. Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_where_u8cond_fp8e4m3_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u8 cond + Fp8E4M3 values, strided variant.
///
/// Besides the operand pointers it takes `rank`, a `shape` array, and one
/// stride array per operand (`stride_cond`, `stride_a`, `stride_b`,
/// `stride_y`). Returns the crate-level `i32` status code (`0` = success).
///
/// NOTE(review): the required length of `shape` / `stride_*` (presumably
/// `rank` entries each) and the stride units are not stated here — confirm
/// against the kernel source.
pub fn baracuda_kernels_where_u8cond_fp8e4m3_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check companion for
/// [`baracuda_kernels_where_u8cond_fp8e4m3_strided_run`]: same `numel`,
/// `rank`, `shape`, stride, and operand arguments (with `y` as `*const`),
/// without the workspace / stream arguments.
pub fn baracuda_kernels_where_u8cond_fp8e4m3_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + Fp8E4M3 values, contig fast path.
///
/// `cond`, `a`, and `b` are read-only inputs; `y` is the output pointer.
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_where_u32cond_fp8e4m3_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check companion for
/// [`baracuda_kernels_where_u32cond_fp8e4m3_run`]: same `numel` and operand
/// pointers (with `y` as `*const`), without the workspace / stream
/// arguments. Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_where_u32cond_fp8e4m3_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, u32 cond + Fp8E4M3 values, strided variant.
///
/// Besides the operand pointers it takes `rank`, a `shape` array, and one
/// stride array per operand (`stride_cond`, `stride_a`, `stride_b`,
/// `stride_y`). Returns the crate-level `i32` status code (`0` = success).
///
/// NOTE(review): the required length of `shape` / `stride_*` (presumably
/// `rank` entries each) and the stride units are not stated here — confirm
/// against the kernel source.
pub fn baracuda_kernels_where_u32cond_fp8e4m3_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check companion for
/// [`baracuda_kernels_where_u32cond_fp8e4m3_strided_run`]: same `numel`,
/// `rank`, `shape`, stride, and operand arguments (with `y` as `*const`),
/// without the workspace / stream arguments.
pub fn baracuda_kernels_where_u32cond_fp8e4m3_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, i64 cond + Fp8E4M3 values, contig fast path.
///
/// `cond`, `a`, and `b` are read-only inputs; `y` is the output pointer.
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_where_i64cond_fp8e4m3_run(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check companion for
/// [`baracuda_kernels_where_i64cond_fp8e4m3_run`]: same `numel` and operand
/// pointers (with `y` as `*const`), without the workspace / stream
/// arguments. Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_where_i64cond_fp8e4m3_can_implement(
numel: i64, cond: *const c_void, a: *const c_void, b: *const c_void,
y: *const c_void,
) -> i32;
/// `where(cond, a, b)`, i64 cond + Fp8E4M3 values, strided variant.
///
/// Besides the operand pointers it takes `rank`, a `shape` array, and one
/// stride array per operand (`stride_cond`, `stride_a`, `stride_b`,
/// `stride_y`). Returns the crate-level `i32` status code (`0` = success).
///
/// NOTE(review): the required length of `shape` / `stride_*` (presumably
/// `rank` entries each) and the stride units are not stated here — confirm
/// against the kernel source.
pub fn baracuda_kernels_where_i64cond_fp8e4m3_strided_run(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64, stride_b: *const i64,
stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check companion for
/// [`baracuda_kernels_where_i64cond_fp8e4m3_strided_run`]: same `numel`,
/// `rank`, `shape`, stride, and operand arguments (with `y` as `*const`),
/// without the workspace / stream arguments.
pub fn baracuda_kernels_where_i64cond_fp8e4m3_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_cond: *const i64, stride_a: *const i64,
stride_b: *const i64, stride_y: *const i64,
cond: *const c_void, a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
}
// ============================================================================
// Elementwise — where backward (heterogeneous-dtype ternary BW, u8 cond + T → T,T)
// ============================================================================
//
// Forward: `y = cond ? a : b`. Backward (cond is non-differentiable):
// da[i] = cond[i] ? dy[i] : 0
// db[i] = cond[i] ? 0 : dy[i]
//
// Pure mask + copy: bit-exact at every dtype. Trailblazer is contig-only
// — broadcasting on dy / da / db is the caller's responsibility (it's
// what the autograd reduction step does upstream of this kernel anyway).
//
// All 4 FP value dtypes wired: {f32, f16, bf16, f64}.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// `where` backward, f32. Writes `da = cond ? dy : 0` and
/// `db = cond ? 0 : dy`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// All device pointers must remain valid for the duration of the
/// launch. `cond` must point to at least `numel` `u8`s; `dy`, `da`,
/// `db` to at least `numel` `f32`s.
pub fn baracuda_kernels_where_backward_f32_run(
numel: i64,
cond: *const c_void,
dy: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check companion for
/// [`baracuda_kernels_where_backward_f32_run`]: same `numel` and pointers
/// (outputs as `*const`), without the workspace / stream arguments.
pub fn baracuda_kernels_where_backward_f32_can_implement(
numel: i64,
cond: *const c_void, dy: *const c_void,
da: *const c_void, db: *const c_void,
) -> i32;
/// `where` backward, f16. Same ABI as
/// [`baracuda_kernels_where_backward_f32_run`] with f16 `dy` / `da` / `db`:
/// writes `da = cond ? dy : 0` and `db = cond ? 0 : dy`.
pub fn baracuda_kernels_where_backward_f16_run(
numel: i64,
cond: *const c_void,
dy: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check companion for
/// [`baracuda_kernels_where_backward_f16_run`]: same `numel` and pointers
/// (outputs as `*const`), without the workspace / stream arguments.
pub fn baracuda_kernels_where_backward_f16_can_implement(
numel: i64,
cond: *const c_void, dy: *const c_void,
da: *const c_void, db: *const c_void,
) -> i32;
/// `where` backward, bf16. Same ABI as
/// [`baracuda_kernels_where_backward_f32_run`] with bf16 `dy` / `da` / `db`:
/// writes `da = cond ? dy : 0` and `db = cond ? 0 : dy`.
pub fn baracuda_kernels_where_backward_bf16_run(
numel: i64,
cond: *const c_void,
dy: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check companion for
/// [`baracuda_kernels_where_backward_bf16_run`]: same `numel` and pointers
/// (outputs as `*const`), without the workspace / stream arguments.
pub fn baracuda_kernels_where_backward_bf16_can_implement(
numel: i64,
cond: *const c_void, dy: *const c_void,
da: *const c_void, db: *const c_void,
) -> i32;
/// `where` backward, f64. Same ABI as
/// [`baracuda_kernels_where_backward_f32_run`] with f64 `dy` / `da` / `db`:
/// writes `da = cond ? dy : 0` and `db = cond ? 0 : dy`.
pub fn baracuda_kernels_where_backward_f64_run(
numel: i64,
cond: *const c_void,
dy: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check companion for
/// [`baracuda_kernels_where_backward_f64_run`]: same `numel` and pointers
/// (outputs as `*const`), without the workspace / stream arguments.
pub fn baracuda_kernels_where_backward_f64_can_implement(
numel: i64,
cond: *const c_void, dy: *const c_void,
da: *const c_void, db: *const c_void,
) -> i32;
}
// ============================================================================
// Elementwise — binary backward ops (Phase 3 backward family)
// ============================================================================
//
// `(da, db) = backward(dy, [saved tensors per op])`. Two ABI shapes:
//
// * **No-save backward** (Add, Sub) — gradient depends only on `dy`.
// ABI: `(numel, dy, da, db, workspace, workspace_bytes, stream)`.
// * **Saves-using backward** (Mul, Div, Pow, Atan2, Hypot, Maximum,
// Minimum) — gradient references the saved forward inputs `a` and `b`.
// ABI: `(numel, dy, a, b, da, db, workspace, workspace_bytes, stream)`.
//
// All nine ops are wired across `{f32, f16, bf16, f64}`.
//
// Status codes match the elementwise forward family: 0 success, 2 invalid
// problem (e.g. negative `numel`, null pointer), 5 internal kernel error.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Add backward, f32. Writes `da = dy` and `db = dy`.
///
/// Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_binary_add_backward_f32_run(
numel: i64,
dy: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_add_backward_f32_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_add_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Add backward, f16. Same ABI as the f32 variant; writes `da = dy` and
/// `db = dy`.
pub fn baracuda_kernels_binary_add_backward_f16_run(
numel: i64,
dy: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_add_backward_f16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_add_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Add backward, bf16. Same ABI as the f32 variant; writes `da = dy` and
/// `db = dy`.
pub fn baracuda_kernels_binary_add_backward_bf16_run(
numel: i64,
dy: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_add_backward_bf16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_add_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Add backward, f64. Same ABI as the f32 variant; writes `da = dy` and
/// `db = dy`.
pub fn baracuda_kernels_binary_add_backward_f64_run(
numel: i64,
dy: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_add_backward_f64_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_add_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Sub backward, f32. Writes `da = dy` and `db = -dy`.
///
/// Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_binary_sub_backward_f32_run(
numel: i64,
dy: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_sub_backward_f32_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_sub_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Sub backward, f16. Same ABI as the f32 variant; writes `da = dy` and
/// `db = -dy`.
pub fn baracuda_kernels_binary_sub_backward_f16_run(
numel: i64,
dy: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_sub_backward_f16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_sub_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Sub backward, bf16. Same ABI as the f32 variant; writes `da = dy` and
/// `db = -dy`.
pub fn baracuda_kernels_binary_sub_backward_bf16_run(
numel: i64,
dy: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_sub_backward_bf16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_sub_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Sub backward, f64. Same ABI as the f32 variant; writes `da = dy` and
/// `db = -dy`.
pub fn baracuda_kernels_binary_sub_backward_f64_run(
numel: i64,
dy: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_sub_backward_f64_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_sub_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Mul backward, f32. Writes `da = dy * b` and `db = dy * a`.
/// Both saved tensors `a` and `b` must be non-null.
///
/// Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_binary_mul_backward_f32_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_mul_backward_f32_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_mul_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Mul backward, f16. Same ABI as the f32 variant; writes `da = dy * b`
/// and `db = dy * a`.
pub fn baracuda_kernels_binary_mul_backward_f16_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_mul_backward_f16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_mul_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Mul backward, bf16. Same ABI as the f32 variant; writes `da = dy * b`
/// and `db = dy * a`.
pub fn baracuda_kernels_binary_mul_backward_bf16_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_mul_backward_bf16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_mul_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Mul backward, f64. Same ABI as the f32 variant; writes `da = dy * b`
/// and `db = dy * a`.
pub fn baracuda_kernels_binary_mul_backward_f64_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_mul_backward_f64_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_mul_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Div backward, f32. Writes `da = dy / b` and `db = -dy * a / b²`.
/// Both saved tensors `a` and `b` must be non-null; callers must
/// also ensure `b[i] != 0` for every cell.
///
/// Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_binary_div_backward_f32_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_div_backward_f32_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_div_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Div backward, f16. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_div_backward_f32_run`] for the formula and the
/// caller's `b[i] != 0` obligation.
pub fn baracuda_kernels_binary_div_backward_f16_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_div_backward_f16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_div_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Div backward, bf16. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_div_backward_f32_run`] for the formula and the
/// caller's `b[i] != 0` obligation.
pub fn baracuda_kernels_binary_div_backward_bf16_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_div_backward_bf16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_div_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Div backward, f64. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_div_backward_f32_run`] for the formula and the
/// caller's `b[i] != 0` obligation.
pub fn baracuda_kernels_binary_div_backward_f64_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_div_backward_f64_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_div_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Pow backward, f32. `da = dy * b * a^(b-1)`, `db = dy * a^b * ln(a)`.
/// Caller responsible for guarding against undefined regions
/// (`a < 0` with non-integer `b`, or `a == 0` with `b < 1`).
///
/// Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_binary_pow_backward_f32_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_pow_backward_f32_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_pow_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Pow backward, f16. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_pow_backward_f32_run`] for the formula and
/// caller obligations.
pub fn baracuda_kernels_binary_pow_backward_f16_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_pow_backward_f16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_pow_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Pow backward, bf16. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_pow_backward_f32_run`] for the formula and
/// caller obligations.
pub fn baracuda_kernels_binary_pow_backward_bf16_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_pow_backward_bf16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_pow_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Pow backward, f64. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_pow_backward_f32_run`] for the formula and
/// caller obligations.
pub fn baracuda_kernels_binary_pow_backward_f64_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_pow_backward_f64_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_pow_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Atan2 backward, f32. `denom = a²+b²`, `da = dy*b/denom`,
/// `db = -dy*a/denom`. Caller responsible for guarding against
/// `a == 0 && b == 0` (denom == 0).
///
/// Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_binary_atan2_backward_f32_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_atan2_backward_f32_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_atan2_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Atan2 backward, f16. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_atan2_backward_f32_run`] for the formula and
/// caller obligations.
pub fn baracuda_kernels_binary_atan2_backward_f16_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_atan2_backward_f16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_atan2_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Atan2 backward, bf16. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_atan2_backward_f32_run`] for the formula and
/// caller obligations.
pub fn baracuda_kernels_binary_atan2_backward_bf16_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_atan2_backward_bf16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_atan2_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Atan2 backward, f64. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_atan2_backward_f32_run`] for the formula and
/// caller obligations.
pub fn baracuda_kernels_binary_atan2_backward_f64_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_atan2_backward_f64_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_atan2_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Hypot backward, f32. `y = sqrt(a²+b²)` is reconstructed inside
/// the kernel from saved `a` and `b` (no saved-y slot in
/// `BinaryBackwardArgs`); `da = dy*a/y`, `db = dy*b/y`. Caller
/// responsible for guarding against `a == 0 && b == 0` (y == 0).
///
/// Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_binary_hypot_backward_f32_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_hypot_backward_f32_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_hypot_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Hypot backward, f16. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_hypot_backward_f32_run`] for the formula and
/// caller obligations.
pub fn baracuda_kernels_binary_hypot_backward_f16_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_hypot_backward_f16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_hypot_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Hypot backward, bf16. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_hypot_backward_f32_run`] for the formula and
/// caller obligations.
pub fn baracuda_kernels_binary_hypot_backward_bf16_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_hypot_backward_bf16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_hypot_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Hypot backward, f64. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_hypot_backward_f32_run`] for the formula and
/// caller obligations.
pub fn baracuda_kernels_binary_hypot_backward_f64_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_hypot_backward_f64_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_hypot_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Maximum backward, f32. Tie-break: split `dy` evenly on `a == b`;
/// NaN inputs propagate `dy` to both. Saved `a` and `b` are used purely
/// as references for the comparison.
///
/// Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_binary_maximum_backward_f32_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_maximum_backward_f32_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_maximum_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Maximum backward, f16. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_maximum_backward_f32_run`] for the gradient
/// rule.
pub fn baracuda_kernels_binary_maximum_backward_f16_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_maximum_backward_f16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_maximum_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Maximum backward, bf16. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_maximum_backward_f32_run`] for the gradient
/// rule.
pub fn baracuda_kernels_binary_maximum_backward_bf16_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_maximum_backward_bf16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_maximum_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Maximum backward, f64. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_maximum_backward_f32_run`] for the gradient
/// rule.
pub fn baracuda_kernels_binary_maximum_backward_f64_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_maximum_backward_f64_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_maximum_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Minimum backward, f32. Tie-break: split `dy` evenly on `a == b`;
/// NaN inputs propagate `dy` to both. Saved `a` and `b` are used purely
/// as references for the comparison.
///
/// Returns an `i32` status code (`0` = success).
pub fn baracuda_kernels_binary_minimum_backward_f32_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_minimum_backward_f32_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_minimum_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Minimum backward, f16. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_minimum_backward_f32_run`] for the gradient
/// rule.
pub fn baracuda_kernels_binary_minimum_backward_f16_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_minimum_backward_f16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_minimum_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Minimum backward, bf16. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_minimum_backward_f32_run`] for the gradient
/// rule.
pub fn baracuda_kernels_binary_minimum_backward_bf16_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_minimum_backward_bf16_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_minimum_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
/// Minimum backward, f64. Same ABI as the f32 variant; see
/// [`baracuda_kernels_binary_minimum_backward_f32_run`] for the gradient
/// rule.
pub fn baracuda_kernels_binary_minimum_backward_f64_run(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *mut c_void,
db: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check for [`baracuda_kernels_binary_minimum_backward_f64_run`]:
/// same `numel` and pointers (outputs as `*const`), no workspace or stream.
pub fn baracuda_kernels_binary_minimum_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
a: *const c_void,
b: *const c_void,
da: *const c_void,
db: *const c_void,
) -> i32;
}
// ============================================================================
// Elementwise — unary backward ops (Phase 3 unary-backward trailblazer)
// ============================================================================
//
// `dx = f'(saved) * dy` for the unary op family. The kernel ABI is
// uniform — one saved tensor of dtype `T` and one gradient input `dy`,
// producing one gradient output `dx`. Which save (`x` or `y`) the
// caller must pass depends on the op's BW formula:
//
// * Saved-x ops (Sin, Cos, Log, ...): caller passes `x` as `saved`.
// Example: Sin BW: `dx = dy * cos(x)`.
// * Saved-y ops (Exp, Sigmoid, Tanh, Sqrt, ...): caller passes `y`
// as `saved`. Example: Exp BW: `dx = dy * y`.
//
// Trailblazer scope: Sin BW × f32 (saved-x) and Exp BW × f32 (saved-y).
// Fanout declarations for further dtypes (e.g. Exp BW × f16 / bf16)
// follow in the same block.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Sin backward, f32. `dx = dy * cos(x)`. Caller must pass the
/// forward input `x` as `saved`.
///
/// `dy` and `saved` are read-only inputs; `dx` is the gradient output.
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_unary_sin_backward_f32_run(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sin_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Takes the same `numel`, `dy`, `saved`, and `dx` (as `*const`) as the
/// matching `_run`, without workspace or stream. Returns the crate-level
/// `i32` status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_sin_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Exp backward, f32. `dx = dy * y`. Caller must pass the forward
/// output `y` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_exp_backward_f32_run(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp_backward_f32`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_exp_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Exp backward, f16. `dx = dy * y`. Saved-y: pass the forward output
/// `y` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_exp_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp_backward_f16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_exp_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Exp backward, bf16. `dx = dy * y`. Saved-y: pass the forward output
/// `y` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_exp_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp_backward_bf16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_exp_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Exp backward, f64. `dx = dy * y`. Saved-y: pass the forward output
/// `y` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_exp_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp_backward_f64`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_exp_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Expm1 backward, f32. `dx = dy * (y + 1)`. Saved-y: pass the forward
/// output `y` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_expm1_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_expm1_backward_f32`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_expm1_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Expm1 backward, f16. `dx = dy * (y + 1)`. Saved-y: pass the forward
/// output `y` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_expm1_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_expm1_backward_f16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_expm1_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Expm1 backward, bf16. `dx = dy * (y + 1)`. Saved-y: pass the forward
/// output `y` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_expm1_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_expm1_backward_bf16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_expm1_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Expm1 backward, f64. `dx = dy * (y + 1)`. Saved-y: pass the forward
/// output `y` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_expm1_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_expm1_backward_f64`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_expm1_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Tanh backward, f32. `dx = dy * (1 - y²)`. Saved-y: pass the forward
/// output `y` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_tanh_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanh_backward_f32`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_tanh_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Tanh backward, f16. `dx = dy * (1 - y²)`. Saved-y: pass the forward
/// output `y` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_tanh_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanh_backward_f16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_tanh_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Tanh backward, bf16. `dx = dy * (1 - y²)`. Saved-y: pass the forward
/// output `y` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_tanh_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanh_backward_bf16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_tanh_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Tanh backward, f64. `dx = dy * (1 - y²)`. Saved-y: pass the forward
/// output `y` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_tanh_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanh_backward_f64`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_tanh_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Sigmoid backward, f32. `dx = dy * y * (1 - y)`. Saved-y: pass the
/// forward output `y` as `saved`. Returns the crate's `i32` status code
/// (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_sigmoid_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sigmoid_backward_f32`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_sigmoid_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Sigmoid backward, f16. `dx = dy * y * (1 - y)`. Saved-y: pass the
/// forward output `y` as `saved`. Returns the crate's `i32` status code
/// (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_sigmoid_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sigmoid_backward_f16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_sigmoid_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Sigmoid backward, bf16. `dx = dy * y * (1 - y)`. Saved-y: pass the
/// forward output `y` as `saved`. Returns the crate's `i32` status code
/// (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_sigmoid_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sigmoid_backward_bf16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_sigmoid_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Sigmoid backward, f64. `dx = dy * y * (1 - y)`. Saved-y: pass the
/// forward output `y` as `saved`. Returns the crate's `i32` status code
/// (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_sigmoid_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sigmoid_backward_f64`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_sigmoid_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Sqrt backward, f32. `dx = dy / (2 * y)`. Saved-y: pass the forward
/// output `y` as `saved`. Callers must ensure `y[i] != 0`. Returns the
/// crate's `i32` status code (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_sqrt_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sqrt_backward_f32`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_sqrt_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Sqrt backward, f16. `dx = dy / (2 * y)`. Saved-y: pass the forward
/// output `y` as `saved`. Callers must ensure `y[i] != 0`. Returns the
/// crate's `i32` status code (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_sqrt_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sqrt_backward_f16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_sqrt_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Sqrt backward, bf16. `dx = dy / (2 * y)`. Saved-y: pass the forward
/// output `y` as `saved`. Callers must ensure `y[i] != 0`. Returns the
/// crate's `i32` status code (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_sqrt_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sqrt_backward_bf16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_sqrt_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Sqrt backward, f64. `dx = dy / (2 * y)`. Saved-y: pass the forward
/// output `y` as `saved`. Callers must ensure `y[i] != 0`. Returns the
/// crate's `i32` status code (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_sqrt_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sqrt_backward_f64`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_sqrt_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Rsqrt backward, f32. `dx = -0.5 * dy * y³`. Saved-y: pass the forward
/// output `y` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_rsqrt_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_rsqrt_backward_f32`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_rsqrt_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Rsqrt backward, f16. `dx = -0.5 * dy * y³`. Saved-y: pass the forward
/// output `y` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_rsqrt_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_rsqrt_backward_f16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_rsqrt_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Rsqrt backward, bf16. `dx = -0.5 * dy * y³`. Saved-y: pass the
/// forward output `y` as `saved`. Returns the crate's `i32` status code
/// (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_rsqrt_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_rsqrt_backward_bf16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_rsqrt_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Rsqrt backward, f64. `dx = -0.5 * dy * y³`. Saved-y: pass the forward
/// output `y` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_rsqrt_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_rsqrt_backward_f64`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_rsqrt_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Sin backward fanout (saves-x, transcendental) ----
/// Sin backward, f16. `dx = dy * cos(x)`. Saved-x: pass the forward
/// input `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_sin_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sin_backward_f16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_sin_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Sin backward, bf16. `dx = dy * cos(x)`. Saved-x: pass the forward
/// input `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_sin_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sin_backward_bf16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_sin_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Sin backward, f64. `dx = dy * cos(x)`. Saved-x: pass the forward
/// input `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_sin_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sin_backward_f64`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_sin_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Log backward (saves-x, no transcendental) ----
/// Log backward, f32. `dx = dy / x`. Saved-x: pass the forward input
/// `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_log_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log_backward_f32`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_log_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Log backward, f16. `dx = dy / x`. Saved-x: pass the forward input
/// `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_log_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log_backward_f16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_log_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Log backward, bf16. `dx = dy / x`. Saved-x: pass the forward input
/// `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_log_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log_backward_bf16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_log_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Log backward, f64. `dx = dy / x`. Saved-x: pass the forward input
/// `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_log_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log_backward_f64`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_log_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Log1p backward (saves-x, no transcendental) ----
/// Log1p backward, f32. `dx = dy / (1 + x)`. Saved-x: pass the forward
/// input `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_log1p_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log1p_backward_f32`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_log1p_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Log1p backward, f16. `dx = dy / (1 + x)`. Saved-x: pass the forward
/// input `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_log1p_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log1p_backward_f16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_log1p_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Log1p backward, bf16. `dx = dy / (1 + x)`. Saved-x: pass the forward
/// input `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_log1p_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log1p_backward_bf16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_log1p_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Log1p backward, f64. `dx = dy / (1 + x)`. Saved-x: pass the forward
/// input `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_log1p_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log1p_backward_f64`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_log1p_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Log2 backward (saves-x, constant ln(2)) ----
/// Log2 backward, f32. `dx = dy / (x * ln(2))`. Saved-x: pass the
/// forward input `x` as `saved`. Returns the crate's `i32` status code
/// (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_log2_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log2_backward_f32`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_log2_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Log2 backward, f16. `dx = dy / (x * ln(2))`. Saved-x: pass the
/// forward input `x` as `saved`. Returns the crate's `i32` status code
/// (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_log2_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log2_backward_f16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_log2_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Log2 backward, bf16. `dx = dy / (x * ln(2))`. Saved-x: pass the
/// forward input `x` as `saved`. Returns the crate's `i32` status code
/// (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_log2_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log2_backward_bf16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_log2_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Log2 backward, f64. `dx = dy / (x * ln(2))`. Saved-x: pass the
/// forward input `x` as `saved`. Returns the crate's `i32` status code
/// (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_log2_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log2_backward_f64`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_log2_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Log10 backward (saves-x, constant ln(10)) ----
/// Log10 backward, f32. `dx = dy / (x * ln(10))`. Saved-x: pass the
/// forward input `x` as `saved`. Returns the crate's `i32` status code
/// (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_log10_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log10_backward_f32`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_log10_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Log10 backward, f16. `dx = dy / (x * ln(10))`. Saved-x: pass the
/// forward input `x` as `saved`. Returns the crate's `i32` status code
/// (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_log10_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log10_backward_f16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_log10_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Log10 backward, bf16. `dx = dy / (x * ln(10))`. Saved-x: pass the
/// forward input `x` as `saved`. Returns the crate's `i32` status code
/// (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_log10_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log10_backward_bf16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_log10_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Log10 backward, f64. `dx = dy / (x * ln(10))`. Saved-x: pass the
/// forward input `x` as `saved`. Returns the crate's `i32` status code
/// (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_log10_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log10_backward_f64`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_log10_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Atan backward (saves-x, no transcendental) ----
/// Atan backward, f32. `dx = dy / (1 + x²)`. Saved-x: pass the forward
/// input `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_atan_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atan_backward_f32`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_atan_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Atan backward, f16. `dx = dy / (1 + x²)`. Saved-x: pass the forward
/// input `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_atan_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atan_backward_f16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_atan_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Atan backward, bf16. `dx = dy / (1 + x²)`. Saved-x: pass the forward
/// input `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_atan_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atan_backward_bf16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_atan_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Atan backward, f64. `dx = dy / (1 + x²)`. Saved-x: pass the forward
/// input `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_atan_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atan_backward_f64`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_atan_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Cos backward (saves-x, transcendental) ----
/// Cos backward, f32. `dx = -dy * sin(x)`. Saved-x: pass the forward
/// input `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_cos_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cos_backward_f32`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_cos_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Cos backward, f16. `dx = -dy * sin(x)`. Saved-x: pass the forward
/// input `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_cos_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cos_backward_f16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_cos_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Cos backward, bf16. `dx = -dy * sin(x)`. Saved-x: pass the forward
/// input `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_cos_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cos_backward_bf16`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_cos_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Cos backward, f64. `dx = -dy * sin(x)`. Saved-x: pass the forward
/// input `x` as `saved`. Returns the crate's `i32` status code (`0` on
/// success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_cos_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cos_backward_f64`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_cos_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Tan backward (saves-x, transcendental) ----
/// Tan backward, f32. `dx = dy * (1 + tan(x)²)`. Saved-x: pass the
/// forward input `x` as `saved`. Returns the crate's `i32` status code
/// (`0` on success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device pointers; `workspace`,
/// when non-null, must cover `workspace_bytes` writable bytes; `stream`
/// must be a valid CUDA stream.
pub fn baracuda_kernels_unary_tan_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tan_backward_f32`.
/// Host-side validation; no kernel launch. Returns the crate's `i32`
/// status code (`0` on success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_tan_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Tan backward, f16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_tan_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tan_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_tan_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Tan backward, bf16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_tan_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tan_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_tan_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Tan backward, f64.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_tan_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tan_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_tan_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Sinh backward (saves-x, transcendental) ----
/// Sinh backward, f32. `dx = dy * cosh(x)`. Saved-x.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_sinh_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sinh_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_sinh_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Sinh backward, f16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_sinh_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sinh_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_sinh_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Sinh backward, bf16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_sinh_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sinh_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_sinh_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Sinh backward, f64.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_sinh_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sinh_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_sinh_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Cosh backward (saves-x, transcendental) ----
/// Cosh backward, f32. `dx = dy * sinh(x)`. Saved-x.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_cosh_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cosh_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_cosh_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Cosh backward, f16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_cosh_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cosh_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_cosh_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Cosh backward, bf16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_cosh_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cosh_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_cosh_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Cosh backward, f64.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_cosh_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cosh_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_cosh_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Asin backward (saves-x, sqrt) ----
/// Asin backward, f32. `dx = dy / sqrt(1 - x²)`. Saved-x. Domain: `|x| < 1`.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_asin_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asin_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_asin_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Asin backward, f16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_asin_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asin_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_asin_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Asin backward, bf16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_asin_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asin_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_asin_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Asin backward, f64.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_asin_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asin_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_asin_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Acos backward (saves-x, sqrt) ----
/// Acos backward, f32. `dx = -dy / sqrt(1 - x²)`. Saved-x. Domain: `|x| < 1`.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_acos_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acos_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_acos_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Acos backward, f16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_acos_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acos_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_acos_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Acos backward, bf16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_acos_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acos_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_acos_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Acos backward, f64.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_acos_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acos_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_acos_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Asinh backward (saves-x, sqrt) ----
/// Asinh backward, f32. `dx = dy / sqrt(1 + x²)`. Saved-x.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_asinh_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asinh_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_asinh_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Asinh backward, f16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_asinh_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asinh_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_asinh_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Asinh backward, bf16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_asinh_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asinh_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_asinh_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Asinh backward, f64.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_asinh_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asinh_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_asinh_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Acosh backward (saves-x, sqrt) ----
/// Acosh backward, f32. `dx = dy / sqrt(x² - 1)`. Saved-x. Domain: `x > 1`.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_acosh_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acosh_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_acosh_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Acosh backward, f16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_acosh_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acosh_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_acosh_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Acosh backward, bf16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_acosh_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acosh_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_acosh_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Acosh backward, f64.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_acosh_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acosh_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_acosh_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Atanh backward (saves-x, no transcendental) ----
/// Atanh backward, f32. `dx = dy / (1 - x²)`. Saved-x. Domain: `|x| < 1`.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_atanh_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atanh_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_atanh_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Atanh backward, f16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_atanh_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atanh_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_atanh_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Atanh backward, bf16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_atanh_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atanh_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_atanh_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Atanh backward, f64.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_atanh_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atanh_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_atanh_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Square backward (saves-x, dy * 2 * x) ----
/// Square backward, f32. `dx = dy * 2 * x`.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_square_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_square_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_square_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Square backward, f16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_square_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_square_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_square_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Square backward, bf16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_square_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_square_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_square_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Square backward, f64.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_square_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_square_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_square_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Cube backward (saves-x, dy * 3 * x²) ----
/// Cube backward, f32. `dx = dy * 3 * x²`.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_cube_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cube_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_cube_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Cube backward, f16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_cube_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cube_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_cube_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Cube backward, bf16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_cube_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cube_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_cube_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Cube backward, f64.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_cube_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cube_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_cube_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Exp2 backward (saves-y, dy * y * ln(2)) ----
/// Exp2 backward, f32. `dx = dy * y * ln(2)`.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_exp2_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp2_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_exp2_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Exp2 backward, f16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_exp2_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp2_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_exp2_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Exp2 backward, bf16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_exp2_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp2_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_exp2_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Exp2 backward, f64.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_exp2_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp2_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_exp2_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Tanhshrink backward (saves-x, dy * tanh(x)²) ----
/// Tanhshrink backward, f32. `dx = dy * tanh(x)²`.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_tanhshrink_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanhshrink_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_tanhshrink_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Tanhshrink backward, f16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_tanhshrink_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanhshrink_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_tanhshrink_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Tanhshrink backward, bf16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_tanhshrink_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanhshrink_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_tanhshrink_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Tanhshrink backward, f64.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_tanhshrink_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanhshrink_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_tanhshrink_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Logit backward (saves-x, dy / (x * (1 - x))) ----
/// Logit backward, f32. `dx = dy / (x * (1 - x))`. Domain `0 < x < 1`.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_logit_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_logit_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_logit_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Logit backward, f16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_logit_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_logit_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_logit_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Logit backward, bf16.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_logit_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_logit_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_logit_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Logit backward, f64.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_logit_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_logit_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_logit_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Reciprocal backward (saves-x, -dy / x²) ----
/// Reciprocal backward, f32. `dx = -dy / x²`. Domain `x != 0`.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses; a non-null
/// `workspace` must point to at least `workspace_bytes` of writable
/// device memory; `stream` must be a valid CUDA stream owned by the
/// calling thread's current context (see the crate-level docs).
pub fn baracuda_kernels_unary_reciprocal_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_reciprocal_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_reciprocal_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Reciprocal backward, f16.
///
/// See `baracuda_kernels_unary_reciprocal_backward_f32_run` for the
/// formula. `dy` and `saved` are `*const` inputs, `dx` is the `*mut`
/// output. Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_reciprocal_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_reciprocal_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_reciprocal_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Reciprocal backward, bf16.
///
/// See `baracuda_kernels_unary_reciprocal_backward_f32_run` for the
/// formula. `dy` and `saved` are `*const` inputs, `dx` is the `*mut`
/// output. Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_reciprocal_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_reciprocal_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_reciprocal_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Reciprocal backward, f64.
///
/// See `baracuda_kernels_unary_reciprocal_backward_f32_run` for the
/// formula. `dy` and `saved` are `*const` inputs, `dx` is the `*mut`
/// output. Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_reciprocal_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_reciprocal_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_reciprocal_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Erf backward (saves-x, transcendental, 2/√π * exp(-x²)) ----
/// Erf backward, f32. `dx = dy * (2/√π) * exp(-x²)`.
///
/// `saved` carries the forward input `x` (saved-x op). `dy` and `saved`
/// are `*const` inputs, `dx` is the `*mut` output. Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_erf_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erf_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_erf_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Erf backward, f16.
///
/// See `baracuda_kernels_unary_erf_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_erf_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erf_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_erf_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Erf backward, bf16.
///
/// See `baracuda_kernels_unary_erf_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_erf_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erf_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_erf_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Erf backward, f64.
///
/// See `baracuda_kernels_unary_erf_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_erf_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erf_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_erf_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Erfc backward (saves-x, transcendental, -2/√π * exp(-x²)) ----
/// Erfc backward, f32. `dx = -dy * (2/√π) * exp(-x²)`.
///
/// `saved` carries the forward input `x` (saved-x op). `dy` and `saved`
/// are `*const` inputs, `dx` is the `*mut` output. Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_erfc_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erfc_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_erfc_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Erfc backward, f16.
///
/// See `baracuda_kernels_unary_erfc_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_erfc_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erfc_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_erfc_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Erfc backward, bf16.
///
/// See `baracuda_kernels_unary_erfc_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_erfc_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erfc_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_erfc_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Erfc backward, f64.
///
/// See `baracuda_kernels_unary_erfc_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_erfc_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erfc_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_erfc_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Activation BW (saved-x, piecewise — Category B' trailblazer + fanout) ----
/// ReLU backward, f32. `dx = (x > 0) ? dy : 0`. Saved-x. This is the
/// activation-BW trailblazer — its aliasing contract carries over to
/// every other `unary_<op>_backward_<dt>_run` (gelu, silu, tanh,
/// sigmoid, elu, leaky_relu, mish, hardswish, hardsigmoid, gelu_tanh,
/// erf, erfc, etc.) across all dtypes, both saved-x and saved-y
/// variants.
///
/// **Aliasing (Phase 64)**: aliasing `dx` with `saved` or `dy` (or
/// both, if `saved == dy`) is safe — the kernel evaluates
/// `dx[i] = f(saved[i]) * dy[i]` (or piecewise variant for ReLU)
/// with each thread touching only its own index `i` (read
/// `saved[i]` + `dy[i]` before write `dx[i]`). Callers
/// implementing in-place activation gradient ops (e.g. an
/// autograd framework reusing the saved-x or dy buffer for the
/// gradient) can dispatch this (out-of-place) symbol with
/// `dx_ptr == saved_ptr` (or `dx_ptr == dy_ptr`) without a
/// dedicated `_inplace_` variant. This contract is stable across
/// baracuda versions and applies to both the saved-x family
/// (ReLU/GELU/SiLU/ELU/HardSwish/HardSigmoid/Mish/LeakyReLU/Erf/Erfc)
/// and the saved-y family (Sigmoid/Tanh).
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_relu_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_relu_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// ReLU backward, f16.
///
/// See `baracuda_kernels_unary_relu_backward_f32_run` for the formula
/// and aliasing contract. `dy` and `saved` are `*const` inputs, `dx` is
/// the `*mut` output. Returns the crate-level `i32` status code
/// (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_relu_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_relu_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// ReLU backward, bf16.
///
/// See `baracuda_kernels_unary_relu_backward_f32_run` for the formula
/// and aliasing contract. `dy` and `saved` are `*const` inputs, `dx` is
/// the `*mut` output. Returns the crate-level `i32` status code
/// (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_relu_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_relu_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// ReLU backward, f64.
///
/// See `baracuda_kernels_unary_relu_backward_f32_run` for the formula
/// and aliasing contract. `dy` and `saved` are `*const` inputs, `dx` is
/// the `*mut` output. Returns the crate-level `i32` status code
/// (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_relu_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_relu_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Hardtanh backward, f32. `dx = (-1 < x < 1) ? dy : 0`. Saved-x.
///
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_hardtanh_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardtanh_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_hardtanh_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Hardtanh backward, f16.
///
/// See `baracuda_kernels_unary_hardtanh_backward_f32_run` for the
/// formula. `dy` and `saved` are `*const` inputs, `dx` is the `*mut`
/// output. Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_hardtanh_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardtanh_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_hardtanh_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Hardtanh backward, bf16.
///
/// See `baracuda_kernels_unary_hardtanh_backward_f32_run` for the
/// formula. `dy` and `saved` are `*const` inputs, `dx` is the `*mut`
/// output. Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_hardtanh_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardtanh_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_hardtanh_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Hardtanh backward, f64.
///
/// See `baracuda_kernels_unary_hardtanh_backward_f32_run` for the
/// formula. `dy` and `saved` are `*const` inputs, `dx` is the `*mut`
/// output. Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_hardtanh_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardtanh_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_hardtanh_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// ReLU6 backward, f32. `dx = (0 < x < 6) ? dy : 0`. Saved-x.
///
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_relu6_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu6_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_relu6_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// ReLU6 backward, f16.
///
/// See `baracuda_kernels_unary_relu6_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_relu6_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu6_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_relu6_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// ReLU6 backward, bf16.
///
/// See `baracuda_kernels_unary_relu6_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_relu6_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu6_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_relu6_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// ReLU6 backward, f64.
///
/// See `baracuda_kernels_unary_relu6_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_relu6_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu6_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_relu6_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Hardsigmoid backward, f32. `dx = (-3 < x < 3) ? dy / 6 : 0`. Saved-x.
///
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_hardsigmoid_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardsigmoid_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_hardsigmoid_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Hardsigmoid backward, f16.
///
/// See `baracuda_kernels_unary_hardsigmoid_backward_f32_run` for the
/// formula. `dy` and `saved` are `*const` inputs, `dx` is the `*mut`
/// output. Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_hardsigmoid_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardsigmoid_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_hardsigmoid_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Hardsigmoid backward, bf16.
///
/// See `baracuda_kernels_unary_hardsigmoid_backward_f32_run` for the
/// formula. `dy` and `saved` are `*const` inputs, `dx` is the `*mut`
/// output. Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_hardsigmoid_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardsigmoid_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_hardsigmoid_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Hardsigmoid backward, f64.
///
/// See `baracuda_kernels_unary_hardsigmoid_backward_f32_run` for the
/// formula. `dy` and `saved` are `*const` inputs, `dx` is the `*mut`
/// output. Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_hardsigmoid_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardsigmoid_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_hardsigmoid_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Hardswish backward, f32. Three-region piecewise + `(2x+3)/6` middle. Saved-x.
///
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_hardswish_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardswish_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_hardswish_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Hardswish backward, f16.
///
/// See `baracuda_kernels_unary_hardswish_backward_f32_run` for the
/// formula. `dy` and `saved` are `*const` inputs, `dx` is the `*mut`
/// output. Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_hardswish_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardswish_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_hardswish_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Hardswish backward, bf16.
///
/// See `baracuda_kernels_unary_hardswish_backward_f32_run` for the
/// formula. `dy` and `saved` are `*const` inputs, `dx` is the `*mut`
/// output. Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_hardswish_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardswish_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_hardswish_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Hardswish backward, f64.
///
/// See `baracuda_kernels_unary_hardswish_backward_f32_run` for the
/// formula. `dy` and `saved` are `*const` inputs, `dx` is the `*mut`
/// output. Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_hardswish_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardswish_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_hardswish_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Softplus backward, f32. `dx = dy / (1 + exp(-x))`. Saved-x.
///
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_softplus_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softplus_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_softplus_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Softplus backward, f16.
///
/// See `baracuda_kernels_unary_softplus_backward_f32_run` for the
/// formula. `dy` and `saved` are `*const` inputs, `dx` is the `*mut`
/// output. Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_softplus_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softplus_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_softplus_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Softplus backward, bf16.
///
/// See `baracuda_kernels_unary_softplus_backward_f32_run` for the
/// formula. `dy` and `saved` are `*const` inputs, `dx` is the `*mut`
/// output. Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_softplus_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softplus_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_softplus_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Softplus backward, f64.
///
/// See `baracuda_kernels_unary_softplus_backward_f32_run` for the
/// formula. `dy` and `saved` are `*const` inputs, `dx` is the `*mut`
/// output. Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_softplus_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softplus_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_softplus_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// SiLU (Swish) backward, f32. `dx = dy * s * (1 + x*(1-s))` with `s = sigmoid(x)`. Saved-x.
///
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_silu_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_silu_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_silu_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// SiLU backward, f16.
///
/// See `baracuda_kernels_unary_silu_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_silu_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_silu_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_silu_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// SiLU backward, bf16.
///
/// See `baracuda_kernels_unary_silu_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_silu_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_silu_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_silu_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// SiLU backward, f64.
///
/// See `baracuda_kernels_unary_silu_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_silu_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_silu_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_silu_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Mish backward, f32. `dx = dy * (tanh(sp) + x*s*(1 - tanh(sp)^2))`,
/// with `sp = softplus(x)` and `s = sigmoid(x)`. Saved-x.
///
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_mish_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_mish_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_mish_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Mish backward, f16.
///
/// See `baracuda_kernels_unary_mish_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_mish_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_mish_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_mish_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Mish backward, bf16.
///
/// See `baracuda_kernels_unary_mish_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_mish_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_mish_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_mish_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Mish backward, f64.
///
/// See `baracuda_kernels_unary_mish_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_mish_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_mish_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_mish_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// GELU (exact / erf-based) backward, f32. `dx = dy * (Φ(x) + x*φ(x))`. Saved-x.
///
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_gelu_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_gelu_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// GELU (erf-based) backward, f16.
///
/// See `baracuda_kernels_unary_gelu_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_gelu_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_gelu_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// GELU (erf-based) backward, bf16.
///
/// See `baracuda_kernels_unary_gelu_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_gelu_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_gelu_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// GELU (erf-based) backward, f64.
///
/// See `baracuda_kernels_unary_gelu_backward_f32_run` for the formula.
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_gelu_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_backward_f64`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_gelu_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// GELU (tanh approximation) backward, f32. Saved-x.
///
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_gelu_tanh_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_tanh_backward_f32`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_gelu_tanh_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// GELU (tanh approximation) backward, f16.
///
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_gelu_tanh_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_tanh_backward_f16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_gelu_tanh_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// GELU (tanh approximation) backward, bf16.
///
/// `dy` and `saved` are `*const` inputs, `dx` is the `*mut` output.
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Raw-pointer FFI call; the crate-level requirements on device pointers,
/// workspace, and stream apply.
pub fn baracuda_kernels_unary_gelu_tanh_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_tanh_backward_bf16`.
/// Host-side validation; no kernel launch.
///
/// Returns the crate-level `i32` status code (`0` = success). `dx` is
/// declared `*const` here; no workspace or stream is passed.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_gelu_tanh_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// GELU (tanh approximation) backward, f64.
///
/// `dy` and `saved` are read-only inputs; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_gelu_tanh_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_tanh_backward_f64`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_gelu_tanh_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// SELU backward, f32. `x>0 → dy*scale`; `x<=0 → dy*scale*alpha*exp(x)`. Saved-x.
///
/// `dy` and `saved` (the saved forward input `x`) are read-only
/// inputs; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_selu_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_selu_backward_f32`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_selu_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// SELU backward, f16.
///
/// `dy` and `saved` are read-only inputs; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_selu_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_selu_backward_f16`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_selu_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// SELU backward, bf16.
///
/// `dy` and `saved` are read-only inputs; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_selu_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_selu_backward_bf16`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_selu_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// SELU backward, f64.
///
/// `dy` and `saved` are read-only inputs; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_selu_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_selu_backward_f64`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_selu_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Parameterized-activation BW (hardcoded defaults — LeakyRelu
// α=0.01, ELU α=1.0, Hardshrink λ=0.5, Softshrink λ=0.5). All
// saved-x, no strided BW path (matches the existing activation BW
// pattern). When the parameterized-unary plan ships these get
// re-emitted with the parameter as a runtime arg. ----
/// LeakyReLU backward, f32. `dx = (x > 0) ? dy : dy·α` with α=0.01. Saved-x.
///
/// α is a hardcoded default, not a runtime argument. `dy` and `saved`
/// (the saved forward input `x`) are read-only inputs; `dx` is the
/// output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_leaky_relu_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_leaky_relu_backward_f32`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_leaky_relu_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// LeakyReLU backward, f16. Saved-x; α is hardcoded to 0.01.
///
/// `dy` and `saved` (the saved forward input `x`) are read-only
/// inputs; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_leaky_relu_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_leaky_relu_backward_f16`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_leaky_relu_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// LeakyReLU backward, bf16. Saved-x; α is hardcoded to 0.01.
///
/// `dy` and `saved` (the saved forward input `x`) are read-only
/// inputs; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_leaky_relu_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_leaky_relu_backward_bf16`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_leaky_relu_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// LeakyReLU backward, f64. Saved-x; α is hardcoded to 0.01.
///
/// `dy` and `saved` (the saved forward input `x`) are read-only
/// inputs; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_leaky_relu_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_leaky_relu_backward_f64`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_leaky_relu_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// ELU backward, f32. `dx = (x > 0) ? dy : dy·α·exp(x)` with α=1.0. Saved-x.
///
/// α is a hardcoded default, not a runtime argument. `dy` and `saved`
/// (the saved forward input `x`) are read-only inputs; `dx` is the
/// output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_elu_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_elu_backward_f32`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_elu_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// ELU backward, f16. Saved-x; α is hardcoded to 1.0.
///
/// `dy` and `saved` (the saved forward input `x`) are read-only
/// inputs; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_elu_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_elu_backward_f16`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_elu_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// ELU backward, bf16. Saved-x; α is hardcoded to 1.0.
///
/// `dy` and `saved` (the saved forward input `x`) are read-only
/// inputs; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_elu_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_elu_backward_bf16`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_elu_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// ELU backward, f64. Saved-x; α is hardcoded to 1.0.
///
/// `dy` and `saved` (the saved forward input `x`) are read-only
/// inputs; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_elu_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_elu_backward_f64`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_elu_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Hardshrink backward, f32. `dx = (|x| > λ) ? dy : 0` with λ=0.5. Saved-x.
///
/// λ is a hardcoded default, not a runtime argument. `dy` and `saved`
/// (the saved forward input `x`) are read-only inputs; `dx` is the
/// output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_hardshrink_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardshrink_backward_f32`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_hardshrink_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Hardshrink backward, f16. Saved-x; λ is hardcoded to 0.5.
///
/// `dy` and `saved` (the saved forward input `x`) are read-only
/// inputs; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_hardshrink_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardshrink_backward_f16`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_hardshrink_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Hardshrink backward, bf16. Saved-x; λ is hardcoded to 0.5.
///
/// `dy` and `saved` (the saved forward input `x`) are read-only
/// inputs; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_hardshrink_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardshrink_backward_bf16`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_hardshrink_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Hardshrink backward, f64. Saved-x; λ is hardcoded to 0.5.
///
/// `dy` and `saved` (the saved forward input `x`) are read-only
/// inputs; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_hardshrink_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardshrink_backward_f64`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_hardshrink_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Softshrink backward, f32. `dx = (|x| > λ) ? dy : 0` with λ=0.5. Saved-x.
///
/// λ is a hardcoded default, not a runtime argument. `dy` and `saved`
/// (the saved forward input `x`) are read-only inputs; `dx` is the
/// output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_softshrink_backward_f32_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softshrink_backward_f32`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_softshrink_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Softshrink backward, f16. Saved-x; λ is hardcoded to 0.5.
///
/// `dy` and `saved` (the saved forward input `x`) are read-only
/// inputs; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_softshrink_backward_f16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softshrink_backward_f16`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_softshrink_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Softshrink backward, bf16. Saved-x; λ is hardcoded to 0.5.
///
/// `dy` and `saved` (the saved forward input `x`) are read-only
/// inputs; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_softshrink_backward_bf16_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softshrink_backward_bf16`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_softshrink_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
/// Softshrink backward, f64. Saved-x; λ is hardcoded to 0.5.
///
/// `dy` and `saved` (the saved forward input `x`) are read-only
/// inputs; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `saved`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
pub fn baracuda_kernels_unary_softshrink_backward_f64_run(
numel: i64, dy: *const c_void, saved: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softshrink_backward_f64`.
/// Host-side validation; no kernel launch.
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_softshrink_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
saved: *const c_void,
dx: *const c_void,
) -> i32;
}
// ============================================================================
// Reduction backward (Phase 4 BW trailblazer)
// ============================================================================
//
// `dx[c] = dy[c with reduce_axis collapsed]` — Sum BW is a pure
// broadcast-copy of dy across the reduced axis. The Rust dispatcher
// constructs the dy strides with `stride[reduce_axis] = 0` so the
// kernel just walks the dx coord space and reads dy via strides.
//
// ABI mirrors the binary strided launcher: `(numel, rank, shape,
// stride_dy, stride_dx, dy, dx, ws, ws_bytes, stream)`. `shape` is the
// full dx shape.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Sum reduction backward, f32. `dx[c] = dy[c_with_reduce_axis_0]`
/// realized via stride-0 broadcast on the reduce axis.
///
/// `shape` is the full `dx` shape; `dy` is read through `stride_dy`
/// (the dispatcher puts stride 0 on the reduced axis) and `dx` is
/// the output buffer. Returns an `i32` status from the crate-level
/// table (`0` = success).
///
/// # Safety
/// `dy` and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_sum_backward_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_dx: *const i64,
dy: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_backward_f32`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_sum_backward_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_dx: *const i64,
dy: *const c_void,
dx: *const c_void,
) -> i32;
/// Sum reduction backward, f16.
///
/// `shape` is the full `dx` shape; `dy` is read through `stride_dy`
/// (the dispatcher puts stride 0 on the reduced axis) and `dx` is
/// the output buffer. Returns an `i32` status from the crate-level
/// table (`0` = success).
///
/// # Safety
/// `dy` and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_sum_backward_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_dx: *const i64,
dy: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_backward_f16`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_sum_backward_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_dx: *const i64,
dy: *const c_void,
dx: *const c_void,
) -> i32;
/// Sum reduction backward, bf16.
///
/// `shape` is the full `dx` shape; `dy` is read through `stride_dy`
/// (the dispatcher puts stride 0 on the reduced axis) and `dx` is
/// the output buffer. Returns an `i32` status from the crate-level
/// table (`0` = success).
///
/// # Safety
/// `dy` and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_sum_backward_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_dx: *const i64,
dy: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_backward_bf16`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_sum_backward_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_dx: *const i64,
dy: *const c_void,
dx: *const c_void,
) -> i32;
/// Sum reduction backward, f64.
///
/// `shape` is the full `dx` shape; `dy` is read through `stride_dy`
/// (the dispatcher puts stride 0 on the reduced axis) and `dx` is
/// the output buffer. Returns an `i32` status from the crate-level
/// table (`0` = success).
///
/// # Safety
/// `dy` and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_sum_backward_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_dx: *const i64,
dy: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_backward_f64`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_sum_backward_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_dx: *const i64,
dy: *const c_void,
dx: *const c_void,
) -> i32;
/// Mean reduction backward, f32. Same as Sum BW with extra `1/k`
/// scale (`inv_extent` is `1.0 / reduced_extent` computed in f64
/// on the host).
///
/// `dy` is read-only and `dx` is the output buffer. Returns an `i32`
/// status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy` and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_mean_backward_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_dx: *const i64,
dy: *const c_void,
dx: *mut c_void,
inv_extent: f64,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_mean_backward_f32`.
///
/// Takes the operand arguments of the matching `_run`, including
/// `inv_extent` (no workspace or stream); returns an `i32` status
/// from the crate-level table (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_mean_backward_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_dx: *const i64,
dy: *const c_void,
dx: *const c_void,
inv_extent: f64,
) -> i32;
/// Mean reduction backward, f16.
///
/// `inv_extent` is passed as f64 (the f32 variant documents it as
/// `1.0 / reduced_extent`, computed on the host). `dy` is read-only
/// and `dx` is the output buffer. Returns an `i32` status from the
/// crate-level table (`0` = success).
///
/// # Safety
/// `dy` and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_mean_backward_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_dx: *const i64,
dy: *const c_void,
dx: *mut c_void,
inv_extent: f64,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_mean_backward_f16`.
///
/// Takes the operand arguments of the matching `_run`, including
/// `inv_extent` (no workspace or stream); returns an `i32` status
/// from the crate-level table (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_mean_backward_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_dx: *const i64,
dy: *const c_void,
dx: *const c_void,
inv_extent: f64,
) -> i32;
/// Mean reduction backward, bf16.
///
/// `inv_extent` is passed as f64 (the f32 variant documents it as
/// `1.0 / reduced_extent`, computed on the host). `dy` is read-only
/// and `dx` is the output buffer. Returns an `i32` status from the
/// crate-level table (`0` = success).
///
/// # Safety
/// `dy` and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_mean_backward_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_dx: *const i64,
dy: *const c_void,
dx: *mut c_void,
inv_extent: f64,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_mean_backward_bf16`.
///
/// Takes the operand arguments of the matching `_run`, including
/// `inv_extent` (no workspace or stream); returns an `i32` status
/// from the crate-level table (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_mean_backward_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_dx: *const i64,
dy: *const c_void,
dx: *const c_void,
inv_extent: f64,
) -> i32;
/// Mean reduction backward, f64.
///
/// `inv_extent` is passed as f64 (the f32 variant documents it as
/// `1.0 / reduced_extent`, computed on the host). `dy` is read-only
/// and `dx` is the output buffer. Returns an `i32` status from the
/// crate-level table (`0` = success).
///
/// # Safety
/// `dy` and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_mean_backward_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_dx: *const i64,
dy: *const c_void,
dx: *mut c_void,
inv_extent: f64,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_mean_backward_f64`.
///
/// Takes the operand arguments of the matching `_run`, including
/// `inv_extent` (no workspace or stream); returns an `i32` status
/// from the crate-level table (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_mean_backward_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_dx: *const i64,
dy: *const c_void,
dx: *const c_void,
inv_extent: f64,
) -> i32;
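// ---- Max / Min reduction backward ----
//
// Single kernel serves BOTH Max BW and Min BW. Compares `x[c]` to
// saved forward output `y[c_reduced]`; matching positions receive
// `dy[c_reduced]`, others get 0. Tie semantic: every tied position
// gets the full gradient.
// NOTE(review): this note originally also read "split-across-ties /
// JAX convention", which contradicts "full gradient" (splitting
// divides `dy` among the tied positions) — confirm which behavior
// the kernel implements and keep only that description.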
/// Max/Min reduction backward, f32.
///
/// One kernel serves both Max BW and Min BW: positions where `x[c]`
/// matches the saved forward output `y[c_reduced]` receive
/// `dy[c_reduced]`, all others get 0. `dy`, `x`, and `y` are
/// read-only; `dx` is the output buffer. Returns an `i32` status
/// from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `x`, `y`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_max_min_backward_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_min_backward_f32`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_max_min_backward_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Max/Min reduction backward, f16.
///
/// One kernel serves both Max BW and Min BW: positions where `x[c]`
/// matches the saved forward output `y[c_reduced]` receive
/// `dy[c_reduced]`, all others get 0. `dy`, `x`, and `y` are
/// read-only; `dx` is the output buffer. Returns an `i32` status
/// from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `x`, `y`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_max_min_backward_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_min_backward_f16`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_max_min_backward_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Max/Min reduction backward, bf16.
///
/// One kernel serves both Max BW and Min BW: positions where `x[c]`
/// matches the saved forward output `y[c_reduced]` receive
/// `dy[c_reduced]`, all others get 0. `dy`, `x`, and `y` are
/// read-only; `dx` is the output buffer. Returns an `i32` status
/// from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `x`, `y`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_max_min_backward_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_min_backward_bf16`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_max_min_backward_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Max/Min reduction backward, f64.
///
/// One kernel serves both Max BW and Min BW: positions where `x[c]`
/// matches the saved forward output `y[c_reduced]` receive
/// `dy[c_reduced]`, all others get 0. `dy`, `x`, and `y` are
/// read-only; `dx` is the output buffer. Returns an `i32` status
/// from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `x`, `y`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_max_min_backward_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_min_backward_f64`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_max_min_backward_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Prod backward — dual-save (saved x and saved y) ------------
// `dx[c] = dy[c_reduced] * y[c_reduced] / x[c]`.
/// Prod reduction backward, f32.
///
/// Dual-save: `dx[c] = dy[c_reduced] * y[c_reduced] / x[c]` using
/// the saved `x` and saved `y`. `dy`, `x`, and `y` are read-only;
/// `dx` is the output buffer. Returns an `i32` status from the
/// crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `x`, `y`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_prod_backward_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_backward_f32`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_prod_backward_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Prod reduction backward, f16.
///
/// Dual-save: `dx[c] = dy[c_reduced] * y[c_reduced] / x[c]` using
/// the saved `x` and saved `y`. `dy`, `x`, and `y` are read-only;
/// `dx` is the output buffer. Returns an `i32` status from the
/// crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `x`, `y`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_prod_backward_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_backward_f16`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_prod_backward_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Prod reduction backward, bf16.
///
/// Dual-save: `dx[c] = dy[c_reduced] * y[c_reduced] / x[c]` using
/// the saved `x` and saved `y`. `dy`, `x`, and `y` are read-only;
/// `dx` is the output buffer. Returns an `i32` status from the
/// crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `x`, `y`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_prod_backward_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_backward_bf16`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_prod_backward_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Prod reduction backward, f64.
///
/// Dual-save: `dx[c] = dy[c_reduced] * y[c_reduced] / x[c]` using
/// the saved `x` and saved `y`. `dy`, `x`, and `y` are read-only;
/// `dx` is the output buffer. Returns an `i32` status from the
/// crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `x`, `y`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_prod_backward_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_backward_f64`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_prod_backward_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Norm2 backward — dual-save (saved x and saved y) -----------
// `dx[c] = dy[c_reduced] * x[c] / y[c_reduced]` where
// `y = sqrt(sum(x², axis=k))`.
/// Norm2 reduction backward, f32.
///
/// Dual-save: `dx[c] = dy[c_reduced] * x[c] / y[c_reduced]`, where
/// `y = sqrt(sum(x², axis=k))`. `dy`, `x`, and `y` are read-only;
/// `dx` is the output buffer. Returns an `i32` status from the
/// crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `x`, `y`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_norm2_backward_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_norm2_backward_f32`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_norm2_backward_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Norm2 reduction backward, f16.
///
/// Dual-save: `dx[c] = dy[c_reduced] * x[c] / y[c_reduced]`, where
/// `y = sqrt(sum(x², axis=k))`. `dy`, `x`, and `y` are read-only;
/// `dx` is the output buffer. Returns an `i32` status from the
/// crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `x`, `y`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_norm2_backward_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_norm2_backward_f16`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_norm2_backward_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Norm2 reduction backward, bf16.
///
/// Dual-save: `dx[c] = dy[c_reduced] * x[c] / y[c_reduced]`, where
/// `y = sqrt(sum(x², axis=k))`. `dy`, `x`, and `y` are read-only;
/// `dx` is the output buffer. Returns an `i32` status from the
/// crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `x`, `y`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_norm2_backward_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_norm2_backward_bf16`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_norm2_backward_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Norm2 reduction backward, f64.
///
/// Dual-save: `dx[c] = dy[c_reduced] * x[c] / y[c_reduced]`, where
/// `y = sqrt(sum(x², axis=k))`. `dy`, `x`, and `y` are read-only;
/// `dx` is the output buffer. Returns an `i32` status from the
/// crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `x`, `y`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_norm2_backward_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_norm2_backward_f64`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_norm2_backward_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
// ---- LogSumExp backward — `dy * exp(x - y)`, dual-save ----------
// `dx[c] = dy[c_reduced] * exp(x[c] - y[c_reduced])` where
// `y = log(sum(exp(x - max), axis=k)) + max`, i.e. the max-shifted
// form of `log(sum(exp(x), axis=k))`. Always numerically safe:
// `x - y ≤ 0`, so `exp(x - y) ∈ (0, 1]`. f16 / bf16 do the exp in
// f32; f32 / f64 use libdevice `expf` / `exp`.
/// LogSumExp reduction backward, f32.
///
/// Dual-save: `dx[c] = dy[c_reduced] * exp(x[c] - y[c_reduced])`.
/// `dy`, `x`, and `y` are read-only; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `x`, `y`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_logsumexp_backward_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_logsumexp_backward_f32`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_logsumexp_backward_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// LogSumExp reduction backward, f16.
///
/// Dual-save: `dx[c] = dy[c_reduced] * exp(x[c] - y[c_reduced])`,
/// with the exp done in f32. `dy`, `x`, and `y` are read-only; `dx`
/// is the output buffer. Returns an `i32` status from the
/// crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `x`, `y`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_logsumexp_backward_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_logsumexp_backward_f16`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_logsumexp_backward_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// LogSumExp reduction backward, bf16.
///
/// Dual-save: `dx[c] = dy[c_reduced] * exp(x[c] - y[c_reduced])`,
/// with the exp done in f32. `dy`, `x`, and `y` are read-only; `dx`
/// is the output buffer. Returns an `i32` status from the
/// crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `x`, `y`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_logsumexp_backward_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_logsumexp_backward_bf16`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_logsumexp_backward_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// LogSumExp reduction backward, f64.
///
/// Dual-save: `dx[c] = dy[c_reduced] * exp(x[c] - y[c_reduced])`.
/// `dy`, `x`, and `y` are read-only; `dx` is the output buffer.
/// Returns an `i32` status from the crate-level table (`0` = success).
///
/// # Safety
/// `dy`, `x`, `y`, and `dx` must be valid device addresses (they are
/// dereferenced without bounds checks). `workspace`, when non-null,
/// must point to at least `workspace_bytes` of writable device
/// memory. `stream` must be a valid CUDA stream owned by the calling
/// thread's current context.
/// NOTE(review): `shape` and the stride arrays presumably hold `rank`
/// entries each — confirm against the launcher.
pub fn baracuda_kernels_reduce_logsumexp_backward_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_logsumexp_backward_f64`.
///
/// Takes the operand arguments of the matching `_run` (no workspace
/// or stream); returns an `i32` status from the crate-level table
/// (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm whether any pointer is dereferenced.
pub fn baracuda_kernels_reduce_logsumexp_backward_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Var / Std backward — Welford BW, f32 variants --------------
// Var BW: `dx[c] = dy[c_reduced] * 2 * (x[c] - mean[c_reduced]) / m`
// Std BW: `dx[c] = dy[c_reduced] * (x[c] - mean[c_reduced]) / (m * y[c_reduced])`
// where `m = max(n - correction, 1)` and `n = reduce_extent`. Mean
// is recomputed inside the kernel (single-pass sum/n over the
// reduce axis on `x`). Saved-x required; saved-y required for Std
// BW and ignored by Var BW (pass null or any valid pointer).
// f32 here — matches the FW Welford scope; the f16 / bf16 / f64
// fanout is declared after the f32 entry points.
/// Variance reduction backward, f32 (Welford BW).
///
/// Computes `dx[c] = dy[c_reduced] * 2 * (x[c] - mean[c_reduced]) / m`
/// with `m = max(n - correction, 1)` and `n = reduce_extent`. The mean
/// is recomputed inside the kernel from `x`.
///
/// # Parameters
/// - `numel`, `rank`, `shape`, `stride_*`: problem extents and one
///   stride array per operand. NOTE(review): presumably `rank` entries
///   each — confirm against the kernel source.
/// - `dy`: upstream gradient, indexed at the reduced coordinate.
/// - `x`: saved forward input (required).
/// - `y`: saved forward output; ignored by Var BW — pass null or any
///   valid pointer.
/// - `dx`: the only `*mut` operand; receives the gradient.
/// - `reduce_axis`, `reduce_extent`, `reduce_stride_x`: describe the
///   reduced axis. `reduce_extent` is `n` in the formula; presumably
///   `reduce_stride_x` is `x`'s stride along that axis — TODO confirm.
/// - `correction`: subtracted from `n` to form the divisor `m`.
/// - `workspace` / `workspace_bytes`: scratch buffer; when non-null it
///   must cover at least `workspace_bytes` of writable device memory.
/// - `stream`: a `cudaStream_t` passed as `*mut c_void`.
///
/// # Safety
/// The crate-level contract applies: pointer arguments are dereferenced
/// without bounds checks.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" table).
pub fn baracuda_kernels_reduce_var_backward_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_var_backward_f32`.
///
/// Mirrors the arguments of `baracuda_kernels_reduce_var_backward_f32_run`
/// up to and including `correction`, but omits the workspace and stream
/// and takes `dx` as `*const`.
///
/// Returns an `i32` status code from the crate-level "Status codes"
/// table (`0` = success, `3` = not supported).
pub fn baracuda_kernels_reduce_var_backward_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
) -> i32;
/// Std-dev reduction backward, f32 (Welford BW + sqrt term).
///
/// Computes
/// `dx[c] = dy[c_reduced] * (x[c] - mean[c_reduced]) / (m * y[c_reduced])`
/// with `m = max(n - correction, 1)` and `n = reduce_extent`. The mean
/// is recomputed inside the kernel from `x`.
///
/// # Parameters
/// - `numel`, `rank`, `shape`, `stride_*`: problem extents and one
///   stride array per operand. NOTE(review): presumably `rank` entries
///   each — confirm against the kernel source.
/// - `dy`: upstream gradient, indexed at the reduced coordinate.
/// - `x`: saved forward input (required).
/// - `y`: saved forward output (required for Std BW — it appears in the
///   divisor).
/// - `dx`: the only `*mut` operand; receives the gradient.
/// - `reduce_axis`, `reduce_extent`, `reduce_stride_x`: describe the
///   reduced axis. `reduce_extent` is `n` in the formula; presumably
///   `reduce_stride_x` is `x`'s stride along that axis — TODO confirm.
/// - `correction`: subtracted from `n` to form the divisor `m`.
/// - `workspace` / `workspace_bytes`: scratch buffer; when non-null it
///   must cover at least `workspace_bytes` of writable device memory.
/// - `stream`: a `cudaStream_t` passed as `*mut c_void`.
///
/// # Safety
/// The crate-level contract applies: pointer arguments are dereferenced
/// without bounds checks.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" table).
pub fn baracuda_kernels_reduce_std_backward_f32_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_std_backward_f32`.
///
/// Mirrors the arguments of `baracuda_kernels_reduce_std_backward_f32_run`
/// up to and including `correction`, but omits the workspace and stream
/// and takes `dx` as `*const`.
///
/// Returns an `i32` status code from the crate-level "Status codes"
/// table (`0` = success, `3` = not supported).
pub fn baracuda_kernels_reduce_std_backward_f32_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
) -> i32;
// ---- Var / Std BW dtype fanout (Phase 4 deferral 4.2 close-out) ----
// Internal accumulation runs at `WelfordAcc<T>`: f32 for
// f16/bf16/f32, f64 for f64. ABI identical to the f32 variants.
/// Variance reduction backward, f16.
///
/// ABI identical to `baracuda_kernels_reduce_var_backward_f32_run`;
/// internal accumulation runs in f32 for this dtype.
///
/// # Parameters
/// - `dy`, `x`, `y`: `*const` operands; `dx` is the only `*mut` operand.
///   NOTE(review): as with the f32 variant, `y` is presumably ignored by
///   Var BW — confirm before passing null.
/// - `reduce_axis`, `reduce_extent`, `reduce_stride_x`, `correction`:
///   reduced-axis description and divisor correction, as for f32.
/// - `workspace` / `workspace_bytes`: scratch buffer; when non-null it
///   must cover at least `workspace_bytes` of writable device memory.
/// - `stream`: a `cudaStream_t` passed as `*mut c_void`.
///
/// # Safety
/// The crate-level contract applies: pointer arguments are dereferenced
/// without bounds checks.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" table).
pub fn baracuda_kernels_reduce_var_backward_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_var_backward_f16`.
///
/// Mirrors the arguments of `baracuda_kernels_reduce_var_backward_f16_run`
/// up to and including `correction`, but omits the workspace and stream
/// and takes `dx` as `*const`.
///
/// Returns an `i32` status code from the crate-level "Status codes"
/// table (`0` = success, `3` = not supported).
pub fn baracuda_kernels_reduce_var_backward_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
) -> i32;
/// Std-dev reduction backward, f16.
///
/// ABI identical to `baracuda_kernels_reduce_std_backward_f32_run`;
/// internal accumulation runs in f32 for this dtype.
///
/// # Parameters
/// - `dy`, `x`, `y`: `*const` operands; `dx` is the only `*mut` operand.
///   NOTE(review): as with the f32 variant, both saved `x` and saved `y`
///   are presumably required — confirm.
/// - `reduce_axis`, `reduce_extent`, `reduce_stride_x`, `correction`:
///   reduced-axis description and divisor correction, as for f32.
/// - `workspace` / `workspace_bytes`: scratch buffer; when non-null it
///   must cover at least `workspace_bytes` of writable device memory.
/// - `stream`: a `cudaStream_t` passed as `*mut c_void`.
///
/// # Safety
/// The crate-level contract applies: pointer arguments are dereferenced
/// without bounds checks.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" table).
pub fn baracuda_kernels_reduce_std_backward_f16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_std_backward_f16`.
///
/// Mirrors the arguments of `baracuda_kernels_reduce_std_backward_f16_run`
/// up to and including `correction`, but omits the workspace and stream
/// and takes `dx` as `*const`.
///
/// Returns an `i32` status code from the crate-level "Status codes"
/// table (`0` = success, `3` = not supported).
pub fn baracuda_kernels_reduce_std_backward_f16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
) -> i32;
/// Variance reduction backward, bf16.
///
/// ABI identical to `baracuda_kernels_reduce_var_backward_f32_run`;
/// internal accumulation runs in f32 for this dtype.
///
/// # Parameters
/// - `dy`, `x`, `y`: `*const` operands; `dx` is the only `*mut` operand.
///   NOTE(review): as with the f32 variant, `y` is presumably ignored by
///   Var BW — confirm before passing null.
/// - `reduce_axis`, `reduce_extent`, `reduce_stride_x`, `correction`:
///   reduced-axis description and divisor correction, as for f32.
/// - `workspace` / `workspace_bytes`: scratch buffer; when non-null it
///   must cover at least `workspace_bytes` of writable device memory.
/// - `stream`: a `cudaStream_t` passed as `*mut c_void`.
///
/// # Safety
/// The crate-level contract applies: pointer arguments are dereferenced
/// without bounds checks.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" table).
pub fn baracuda_kernels_reduce_var_backward_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_var_backward_bf16`.
///
/// Mirrors the arguments of `baracuda_kernels_reduce_var_backward_bf16_run`
/// up to and including `correction`, but omits the workspace and stream
/// and takes `dx` as `*const`.
///
/// Returns an `i32` status code from the crate-level "Status codes"
/// table (`0` = success, `3` = not supported).
pub fn baracuda_kernels_reduce_var_backward_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
) -> i32;
/// Std-dev reduction backward, bf16.
///
/// ABI identical to `baracuda_kernels_reduce_std_backward_f32_run`;
/// internal accumulation runs in f32 for this dtype.
///
/// # Parameters
/// - `dy`, `x`, `y`: `*const` operands; `dx` is the only `*mut` operand.
///   NOTE(review): as with the f32 variant, both saved `x` and saved `y`
///   are presumably required — confirm.
/// - `reduce_axis`, `reduce_extent`, `reduce_stride_x`, `correction`:
///   reduced-axis description and divisor correction, as for f32.
/// - `workspace` / `workspace_bytes`: scratch buffer; when non-null it
///   must cover at least `workspace_bytes` of writable device memory.
/// - `stream`: a `cudaStream_t` passed as `*mut c_void`.
///
/// # Safety
/// The crate-level contract applies: pointer arguments are dereferenced
/// without bounds checks.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" table).
pub fn baracuda_kernels_reduce_std_backward_bf16_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_std_backward_bf16`.
///
/// Mirrors the arguments of `baracuda_kernels_reduce_std_backward_bf16_run`
/// up to and including `correction`, but omits the workspace and stream
/// and takes `dx` as `*const`.
///
/// Returns an `i32` status code from the crate-level "Status codes"
/// table (`0` = success, `3` = not supported).
pub fn baracuda_kernels_reduce_std_backward_bf16_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
) -> i32;
/// Variance reduction backward, f64 (Welford BW in f64).
///
/// ABI identical to `baracuda_kernels_reduce_var_backward_f32_run`;
/// internal accumulation runs in f64 for this dtype.
///
/// # Parameters
/// - `dy`, `x`, `y`: `*const` operands; `dx` is the only `*mut` operand.
///   NOTE(review): as with the f32 variant, `y` is presumably ignored by
///   Var BW — confirm before passing null.
/// - `reduce_axis`, `reduce_extent`, `reduce_stride_x`, `correction`:
///   reduced-axis description and divisor correction, as for f32.
/// - `workspace` / `workspace_bytes`: scratch buffer; when non-null it
///   must cover at least `workspace_bytes` of writable device memory.
/// - `stream`: a `cudaStream_t` passed as `*mut c_void`.
///
/// # Safety
/// The crate-level contract applies: pointer arguments are dereferenced
/// without bounds checks.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" table).
pub fn baracuda_kernels_reduce_var_backward_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_var_backward_f64`.
///
/// Mirrors the arguments of `baracuda_kernels_reduce_var_backward_f64_run`
/// up to and including `correction`, but omits the workspace and stream
/// and takes `dx` as `*const`.
///
/// Returns an `i32` status code from the crate-level "Status codes"
/// table (`0` = success, `3` = not supported).
pub fn baracuda_kernels_reduce_var_backward_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
) -> i32;
/// Std-dev reduction backward, f64 (Welford BW in f64 + sqrt term).
///
/// ABI identical to `baracuda_kernels_reduce_std_backward_f32_run`;
/// internal accumulation runs in f64 for this dtype.
///
/// # Parameters
/// - `dy`, `x`, `y`: `*const` operands; `dx` is the only `*mut` operand.
///   NOTE(review): as with the f32 variant, both saved `x` and saved `y`
///   are presumably required — confirm.
/// - `reduce_axis`, `reduce_extent`, `reduce_stride_x`, `correction`:
///   reduced-axis description and divisor correction, as for f32.
/// - `workspace` / `workspace_bytes`: scratch buffer; when non-null it
///   must cover at least `workspace_bytes` of writable device memory.
/// - `stream`: a `cudaStream_t` passed as `*mut c_void`.
///
/// # Safety
/// The crate-level contract applies: pointer arguments are dereferenced
/// without bounds checks.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// "Status codes" table).
pub fn baracuda_kernels_reduce_std_backward_f64_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *mut c_void,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_std_backward_f64`.
///
/// Mirrors the arguments of `baracuda_kernels_reduce_std_backward_f64_run`
/// up to and including `correction`, but omits the workspace and stream
/// and takes `dx` as `*const`.
///
/// Returns an `i32` status code from the crate-level "Status codes"
/// table (`0` = success, `3` = not supported).
pub fn baracuda_kernels_reduce_std_backward_f64_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
dy: *const c_void,
x: *const c_void,
y: *const c_void,
dx: *const c_void,
reduce_axis: i32,
reduce_extent: i32,
reduce_stride_x: i64,
correction: i32,
) -> i32;
}
// ============================================================================
// Elementwise — binary comparison ops (T → bool)
// ============================================================================
//
// Same shape as the binary contig + strided launchers above, but the
// output is `uint8_t` (0 / 1) rather than `T`. The kernel returns the
// comparison result as a bool stored in one byte — PyTorch / NumPy
// convention. The C ABI uses `void*` for the output pointer; the
// kernel wrapper casts it to `uint8_t*` internally.
//
// Full matrix wired: {Eq, Ne, Gt, Ge, Lt, Le} ops × {f32, f16, bf16,
// f64} dtypes × {contig, strided} = 48 launchers (4 symbols per
// op × dtype cell: `_run`, `_can_implement`, `_strided_run`,
// `_strided_can_implement`). NaN handling follows IEEE 754: `Eq` /
// ordered comparisons return 0 when either operand is NaN; `Ne`
// returns 1 (since `NaN != anything`).
// `eq` launchers. For each of f32 / f16 / bf16 / f64 this block declares
// the contig pair (`_run`, `_can_implement`) and the strided pair
// (`_strided_run`, `_strided_can_implement`). Every function returns an
// `i32` status code (`0` = success; see the crate-level "Status codes"
// table). Only compiled when one of the sm80 / sm89 / sm90a features is
// enabled.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// --- Eq -----------------------------------------------------------------
/// Binary elementwise `eq`, f32 inputs, u8 output, contig fast path.
///
/// Writes 0 when either operand is NaN (IEEE 754 equality).
///
/// # Safety
/// All device pointers must remain valid for the duration of the
/// launch. `y` must point to at least `numel` `u8`s. The kernel
/// writes only `0u8` and `1u8` to `y`.
pub fn baracuda_kernels_binary_cmp_eq_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_eq_f32`.
///
/// Takes the element count and the three operand pointers only (all
/// `*const`); no workspace or stream.
pub fn baracuda_kernels_binary_cmp_eq_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `eq`, f32 inputs, u8 output, strided path.
///
/// Handles non-contig views (broadcast / transposed / sliced). The
/// output's stride is in u8 elements (one element per byte).
///
/// NOTE(review): `shape` and the three stride arrays presumably hold
/// `rank` entries each — confirm against the kernel source.
pub fn baracuda_kernels_binary_cmp_eq_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `eq` f32 launcher.
///
/// Same leading arguments as
/// `baracuda_kernels_binary_cmp_eq_f32_strided_run`, with `y` taken as
/// `*const` and no workspace or stream.
pub fn baracuda_kernels_binary_cmp_eq_f32_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `eq`, f16 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`. Inputs are
/// `__half`. The only rounding step is the one taken when the values
/// were stored as f16; the comparison itself adds no further
/// rounding, so the GPU result matches host
/// `half::f16 == half::f16`.
pub fn baracuda_kernels_binary_cmp_eq_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_eq_f16`.
pub fn baracuda_kernels_binary_cmp_eq_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `eq`, f16 inputs, u8 output, strided path.
pub fn baracuda_kernels_binary_cmp_eq_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `eq` f16 launcher.
pub fn baracuda_kernels_binary_cmp_eq_f16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `eq`, bf16 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_eq_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_eq_bf16`.
pub fn baracuda_kernels_binary_cmp_eq_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `eq`, bf16 inputs, u8 output, strided path.
pub fn baracuda_kernels_binary_cmp_eq_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `eq` bf16 launcher.
pub fn baracuda_kernels_binary_cmp_eq_bf16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `eq`, f64 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_eq_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_eq_f64`.
pub fn baracuda_kernels_binary_cmp_eq_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `eq`, f64 inputs, u8 output, strided path.
pub fn baracuda_kernels_binary_cmp_eq_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `eq` f64 launcher.
pub fn baracuda_kernels_binary_cmp_eq_f64_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
}
// --- Ne -------------------------------------------------------------------
// `ne` launchers. For each of f32 / f16 / bf16 / f64 this block declares
// the contig pair (`_run`, `_can_implement`) and the strided pair
// (`_strided_run`, `_strided_can_implement`). Every function returns an
// `i32` status code (`0` = success; see the crate-level "Status codes"
// table).
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Binary elementwise `ne`, f32 inputs, u8 output, contig fast path.
///
/// `NaN != anything` returns 1 per IEEE 754.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_ne_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_ne_f32`.
///
/// Takes the element count and the three operand pointers only (all
/// `*const`); no workspace or stream.
pub fn baracuda_kernels_binary_cmp_ne_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `ne`, f32 inputs, u8 output, strided path.
///
/// NOTE(review): `shape` and the three stride arrays presumably hold
/// `rank` entries each — confirm against the kernel source.
pub fn baracuda_kernels_binary_cmp_ne_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `ne` f32 launcher.
///
/// Same leading arguments as
/// `baracuda_kernels_binary_cmp_ne_f32_strided_run`, with `y` taken as
/// `*const` and no workspace or stream.
pub fn baracuda_kernels_binary_cmp_ne_f32_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `ne`, f16 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_ne_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_ne_f16`.
pub fn baracuda_kernels_binary_cmp_ne_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `ne`, f16 inputs, u8 output, strided path.
pub fn baracuda_kernels_binary_cmp_ne_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `ne` f16 launcher.
pub fn baracuda_kernels_binary_cmp_ne_f16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `ne`, bf16 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_ne_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_ne_bf16`.
pub fn baracuda_kernels_binary_cmp_ne_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `ne`, bf16 inputs, u8 output, strided path.
pub fn baracuda_kernels_binary_cmp_ne_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `ne` bf16 launcher.
pub fn baracuda_kernels_binary_cmp_ne_bf16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `ne`, f64 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_ne_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_ne_f64`.
pub fn baracuda_kernels_binary_cmp_ne_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `ne`, f64 inputs, u8 output, strided path.
pub fn baracuda_kernels_binary_cmp_ne_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `ne` f64 launcher.
pub fn baracuda_kernels_binary_cmp_ne_f64_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
}
// --- Gt -------------------------------------------------------------------
// `gt` launchers. For each of f32 / f16 / bf16 / f64 this block declares
// the contig pair (`_run`, `_can_implement`) and the strided pair
// (`_strided_run`, `_strided_can_implement`). Every function returns an
// `i32` status code (`0` = success; see the crate-level "Status codes"
// table).
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Binary elementwise `gt` (`a > b`), f32 inputs, u8 output, contig fast path.
///
/// Any comparison involving NaN returns 0 per IEEE 754.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_gt_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_gt_f32`.
///
/// Takes the element count and the three operand pointers only (all
/// `*const`); no workspace or stream.
pub fn baracuda_kernels_binary_cmp_gt_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `gt`, f32 inputs, u8 output, strided path.
///
/// NOTE(review): `shape` and the three stride arrays presumably hold
/// `rank` entries each — confirm against the kernel source.
pub fn baracuda_kernels_binary_cmp_gt_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `gt` f32 launcher.
///
/// Same leading arguments as
/// `baracuda_kernels_binary_cmp_gt_f32_strided_run`, with `y` taken as
/// `*const` and no workspace or stream.
pub fn baracuda_kernels_binary_cmp_gt_f32_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `gt`, f16 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_gt_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_gt_f16`.
pub fn baracuda_kernels_binary_cmp_gt_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `gt`, f16 inputs, u8 output, strided path.
pub fn baracuda_kernels_binary_cmp_gt_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `gt` f16 launcher.
pub fn baracuda_kernels_binary_cmp_gt_f16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `gt`, bf16 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_gt_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_gt_bf16`.
pub fn baracuda_kernels_binary_cmp_gt_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `gt`, bf16 inputs, u8 output, strided path.
pub fn baracuda_kernels_binary_cmp_gt_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `gt` bf16 launcher.
pub fn baracuda_kernels_binary_cmp_gt_bf16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `gt`, f64 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_gt_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_gt_f64`.
pub fn baracuda_kernels_binary_cmp_gt_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `gt`, f64 inputs, u8 output, strided path.
pub fn baracuda_kernels_binary_cmp_gt_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `gt` f64 launcher.
pub fn baracuda_kernels_binary_cmp_gt_f64_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
}
// --- Ge -------------------------------------------------------------------
// `ge` launchers. For each of f32 / f16 / bf16 / f64 this block declares
// the contig pair (`_run`, `_can_implement`) and the strided pair
// (`_strided_run`, `_strided_can_implement`). Every function returns an
// `i32` status code (`0` = success; see the crate-level "Status codes"
// table).
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Binary elementwise `ge` (`a >= b`), f32 inputs, u8 output, contig fast path.
///
/// Any comparison involving NaN returns 0 per IEEE 754.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_ge_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_ge_f32`.
///
/// Takes the element count and the three operand pointers only (all
/// `*const`); no workspace or stream.
pub fn baracuda_kernels_binary_cmp_ge_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `ge`, f32 inputs, u8 output, strided path.
///
/// NOTE(review): `shape` and the three stride arrays presumably hold
/// `rank` entries each — confirm against the kernel source.
pub fn baracuda_kernels_binary_cmp_ge_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `ge` f32 launcher.
///
/// Same leading arguments as
/// `baracuda_kernels_binary_cmp_ge_f32_strided_run`, with `y` taken as
/// `*const` and no workspace or stream.
pub fn baracuda_kernels_binary_cmp_ge_f32_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `ge`, f16 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_ge_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_ge_f16`.
pub fn baracuda_kernels_binary_cmp_ge_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `ge`, f16 inputs, u8 output, strided path.
pub fn baracuda_kernels_binary_cmp_ge_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `ge` f16 launcher.
pub fn baracuda_kernels_binary_cmp_ge_f16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `ge`, bf16 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_ge_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_ge_bf16`.
pub fn baracuda_kernels_binary_cmp_ge_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `ge`, bf16 inputs, u8 output, strided path.
pub fn baracuda_kernels_binary_cmp_ge_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `ge` bf16 launcher.
pub fn baracuda_kernels_binary_cmp_ge_bf16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `ge`, f64 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_ge_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_ge_f64`.
pub fn baracuda_kernels_binary_cmp_ge_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `ge`, f64 inputs, u8 output, strided path.
pub fn baracuda_kernels_binary_cmp_ge_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `ge` f64 launcher.
pub fn baracuda_kernels_binary_cmp_ge_f64_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
}
// --- Lt -------------------------------------------------------------------
// `lt` launchers. For each of f32 / f16 / bf16 / f64 this block declares
// the contig pair (`_run`, `_can_implement`) and the strided pair
// (`_strided_run`, `_strided_can_implement`). Every function returns an
// `i32` status code (`0` = success; see the crate-level "Status codes"
// table).
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Binary elementwise `lt` (`a < b`), f32 inputs, u8 output, contig fast path.
///
/// Any comparison involving NaN returns 0 per IEEE 754.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_lt_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_lt_f32`.
///
/// Takes the element count and the three operand pointers only (all
/// `*const`); no workspace or stream.
pub fn baracuda_kernels_binary_cmp_lt_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `lt`, f32 inputs, u8 output, strided path.
///
/// NOTE(review): `shape` and the three stride arrays presumably hold
/// `rank` entries each — confirm against the kernel source.
pub fn baracuda_kernels_binary_cmp_lt_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `lt` f32 launcher.
///
/// Same leading arguments as
/// `baracuda_kernels_binary_cmp_lt_f32_strided_run`, with `y` taken as
/// `*const` and no workspace or stream.
pub fn baracuda_kernels_binary_cmp_lt_f32_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `lt`, f16 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_lt_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_lt_f16`.
pub fn baracuda_kernels_binary_cmp_lt_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `lt`, f16 inputs, u8 output, strided path.
pub fn baracuda_kernels_binary_cmp_lt_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `lt` f16 launcher.
pub fn baracuda_kernels_binary_cmp_lt_f16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `lt`, bf16 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_lt_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_lt_bf16`.
pub fn baracuda_kernels_binary_cmp_lt_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `lt`, bf16 inputs, u8 output, strided path.
pub fn baracuda_kernels_binary_cmp_lt_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `lt` bf16 launcher.
pub fn baracuda_kernels_binary_cmp_lt_bf16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `lt`, f64 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_lt_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_lt_f64`.
pub fn baracuda_kernels_binary_cmp_lt_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `lt`, f64 inputs, u8 output, strided path.
pub fn baracuda_kernels_binary_cmp_lt_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for the strided `lt` f64 launcher.
pub fn baracuda_kernels_binary_cmp_lt_f64_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
}
// --- Le -------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Launchers for the binary comparison `le` (`a <= b`). Each input dtype
// (f32, f16, bf16, f64) declares four symbols: a contig `_run`, its
// `_can_implement` check, a `_strided_run`, and its
// `_strided_can_implement` check. Every symbol returns the `i32` status
// code described in the crate-level docs.
//
// NOTE(review): only the f32 contig launcher carries an explicit safety
// reference in the original docs. The `# Safety` notes on the remaining
// symbols point at their f32 siblings on the assumption that the contract
// is identical apart from the input dtype — confirm against the launchers.
/// Binary elementwise `le` (`a <= b`), f32 inputs, u8 output, contig fast path.
///
/// Any comparison involving NaN returns 0 per IEEE 754.
///
/// Returns the `i32` status code described in the crate-level docs.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_eq_f32_run`.
pub fn baracuda_kernels_binary_cmp_le_f32_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_le_f32`.
///
/// Takes the `numel` / `a` / `b` / `y` arguments of
/// `baracuda_kernels_binary_cmp_le_f32_run` without the workspace and
/// stream arguments.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_le_f32_run`.
pub fn baracuda_kernels_binary_cmp_le_f32_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `le`, f32 inputs, u8 output, strided path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_le_f32_run` for the device-pointer
/// and stream arguments. NOTE(review): `shape`, `stride_a`, `stride_b`
/// and `stride_y` are presumably host-side arrays of at least `rank`
/// elements, as documented for the unary strided launchers — confirm.
pub fn baracuda_kernels_binary_cmp_le_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check companion to
/// `baracuda_kernels_binary_cmp_le_f32_strided_run`.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_le_f32_strided_run`.
pub fn baracuda_kernels_binary_cmp_le_f32_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `le`, f16 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_le_f32_run`; inputs are f16 here.
pub fn baracuda_kernels_binary_cmp_le_f16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_le_f16`.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_le_f16_run`.
pub fn baracuda_kernels_binary_cmp_le_f16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `le`, f16 inputs, u8 output, strided path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_le_f32_strided_run`; inputs are f16
/// here.
pub fn baracuda_kernels_binary_cmp_le_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check companion to
/// `baracuda_kernels_binary_cmp_le_f16_strided_run`.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_le_f16_strided_run`.
pub fn baracuda_kernels_binary_cmp_le_f16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `le`, bf16 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_le_f32_run`; inputs are bf16 here.
pub fn baracuda_kernels_binary_cmp_le_bf16_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_le_bf16`.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_le_bf16_run`.
pub fn baracuda_kernels_binary_cmp_le_bf16_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `le`, bf16 inputs, u8 output, strided path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_le_f32_strided_run`; inputs are bf16
/// here.
pub fn baracuda_kernels_binary_cmp_le_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check companion to
/// `baracuda_kernels_binary_cmp_le_bf16_strided_run`.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_le_bf16_strided_run`.
pub fn baracuda_kernels_binary_cmp_le_bf16_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
/// Binary elementwise `le`, f64 inputs, u8 output, contig fast path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_le_f32_run`; inputs are f64 here.
pub fn baracuda_kernels_binary_cmp_le_f64_run(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `binary_cmp_le_f64`.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_le_f64_run`.
pub fn baracuda_kernels_binary_cmp_le_f64_can_implement(
numel: i64,
a: *const c_void,
b: *const c_void,
y: *const c_void,
) -> i32;
/// Binary elementwise `le`, f64 inputs, u8 output, strided path.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_le_f32_strided_run`; inputs are f64
/// here.
pub fn baracuda_kernels_binary_cmp_le_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_a: *const i64,
stride_b: *const i64,
stride_y: *const i64,
a: *const c_void,
b: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch check companion to
/// `baracuda_kernels_binary_cmp_le_f64_strided_run`.
///
/// # Safety
/// See `baracuda_kernels_binary_cmp_le_f64_strided_run`.
pub fn baracuda_kernels_binary_cmp_le_f64_strided_can_implement(
numel: i64, rank: i32, shape: *const i32,
stride_a: *const i64, stride_b: *const i64, stride_y: *const i64,
a: *const c_void, b: *const c_void, y: *const c_void,
) -> i32;
}
// ============================================================================
// Elementwise — unary (1→1) ops
// ============================================================================
//
// Same INSTANTIATE-driven kernel-family pattern as the binary path
// above, but for 1→1 ops (`y = f(x)`). Both contig and strided
// variants ship per (op, dtype) cell. The Rust dispatcher picks the
// fast contig path when input + output are both contiguous, else
// strided.
//
// ABI shape mirrors the binary launchers minus the second operand;
// strided variants drop the `stride_b` array too.
//
// Status codes mirror the GEMM family (see crate-level doc).
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Launchers for unary `neg`. Each dtype (f32, f16, bf16, f64) declares
// four symbols: a contig `_run`, its `_can_implement` check, a
// `_strided_run`, and its `_strided_can_implement` check. The f32 contig
// and f32 strided launchers carry the full safety contracts that the
// other unary launchers refer back to.
/// Unary elementwise `neg`, f32 dtype, contiguous fast path. This
/// is the unary-pointwise trailblazer — its safety contract carries
/// over to every plain unary launcher (`neg`, `abs`, `square`, `sqrt`,
/// `rsqrt`, `reciprocal`, `exp`, `log`, `sin`, `cos`, `tan`, `sign`,
/// `floor`, `ceil`, `round`, `erf`, `relu`, `silu`, `gelu`, `tanh`,
/// `sigmoid`, etc.) AND every parameterized-unary launcher
/// (`unary_param_*` family: `powi`, `threshold`, `elu`, `prelu`,
/// `lerp`, etc.) across all dtypes. See also `binary_add_f32_run`
/// for the binary contig aliasing contract and `ternary_clamp_f32_run`
/// for the ternary one.
///
/// Returns the `i32` status code described in the crate-level docs.
///
/// # Safety
/// All pointer args must be device-resident and remain valid for the
/// duration of the launch. `stream` must be a live CUDA stream in
/// the current context. `x` and `y` must each point to at least
/// `numel` `float`s of device memory. `workspace`, when non-null,
/// must point to at least `workspace_bytes` bytes of writable device
/// memory (see the crate-level docs).
///
/// **Aliasing**: aliasing `y` with `x` is safe — each thread reads
/// `x[i]` before writing `y[i]`, with no cross-index dependencies.
/// Callers implementing in-place elementwise unary ops (e.g. Fuel's
/// `Op::ReluInplace`, `Op::SiluInplace`, etc.) can dispatch the
/// forward symbol with `x_ptr == y_ptr` without a dedicated
/// `_inplace_` variant. This contract is stable across baracuda
/// versions.
pub fn baracuda_kernels_unary_neg_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_neg_f32`. Validates
/// the problem size without launching a kernel. Returns the standard
/// status code mapping.
///
/// # Safety
/// Same pointer-validity contract as the corresponding `_run` fn,
/// but no device dereferences occur — only host-side checks.
pub fn baracuda_kernels_unary_neg_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `neg`, f32 dtype, strided path. This is the
/// unary-strided trailblazer — its safety contract (including
/// aliasing) carries over to every other unary strided launcher
/// AND every parameterized-unary strided launcher (`powi`,
/// `threshold`, `elu`, `prelu`, `lerp`) across all dtypes.
///
/// Handles non-contig views (transposed, sliced). Input shape must
/// equal output shape — broadcast is not a meaningful unary
/// semantic and is rejected by the Rust dispatcher upstream.
///
/// Returns the `i32` status code described in the crate-level docs.
///
/// # Safety
/// Same device-pointer contract as the contig launcher. `shape`,
/// `stride_x`, `stride_y` are host-side pointers to arrays of at
/// least `rank` elements that must remain valid for the duration
/// of the host-side launch call (the launcher copies them into
/// the kernel parameter block before returning).
///
/// **Aliasing (Phase 62)**: aliasing `y` with `x` is safe IF AND
/// ONLY IF `stride_x == stride_y` element-for-element (use
/// [`baracuda_kernels_types::strides_equal`] to check). With equal
/// strides, each thread reads its own `off` cell then writes the
/// same cell, identical structure to the contig unary case. With
/// unequal strides, different threads can read cells that other
/// threads have already overwritten — silent data corruption. The
/// kernel does no validation; this is the caller's contract. The
/// `__restrict__` qualifiers on the kernel signature are an
/// optimizer hint — they are safe to violate only when the
/// per-thread access pattern remains read-then-write at the same
/// cell, i.e., when strides are equal.
///
/// Additional preconditions (apply with or without aliasing): no
/// zero strides on `y`, and `(shape, stride_y)` must specify a
/// valid permutation (no two linear `i` values mapping to the
/// same `off_y` cell).
///
/// This contract is stable across baracuda versions.
pub fn baracuda_kernels_unary_neg_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_neg_f32_strided`.
///
/// Takes the shape / stride / pointer arguments of
/// `baracuda_kernels_unary_neg_f32_strided_run` without the workspace
/// and stream arguments.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_neg_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `neg`, f16 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as `baracuda_kernels_unary_neg_f32_run`. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_neg_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_neg_f16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_neg_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `neg`, f16 dtype, strided path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_neg_f32_strided_run`. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_neg_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_neg_f16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_neg_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `neg`, bf16 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as `baracuda_kernels_unary_neg_f32_run`. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_neg_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_neg_bf16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_neg_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `neg`, bf16 dtype, strided path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_neg_f32_strided_run`. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_neg_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_neg_bf16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_neg_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `neg`, f64 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as `baracuda_kernels_unary_neg_f32_run`. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_neg_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_neg_f64`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_neg_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `neg`, f64 dtype, strided path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_neg_f32_strided_run`. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_neg_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_neg_f64_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_neg_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `abs` — `y = |x|` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Launchers for unary `abs` (`y = |x|`). Each dtype (f32, f16, bf16, f64)
// declares four symbols: a contig `_run`, its `_can_implement` check, a
// `_strided_run`, and its `_strided_can_implement` check. Every symbol
// returns the `i32` status code described in the crate-level docs.
/// Unary elementwise `abs` (`y = |x|`), f32 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as `baracuda_kernels_unary_neg_f32_run`.
pub fn baracuda_kernels_unary_abs_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_abs_f32`.
///
/// # Safety
/// Host-side checks only — same pointer-validity contract as the `_run` fn.
pub fn baracuda_kernels_unary_abs_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `abs`, f32 dtype, strided path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_neg_f32_strided_run`.
pub fn baracuda_kernels_unary_abs_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_abs_f32_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_abs_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `abs`, f16 dtype, contiguous fast path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_abs_f32_run`; `x` / `y`
/// point to `__half` storage.
pub fn baracuda_kernels_unary_abs_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_abs_f16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_abs_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `abs`, f16 dtype, strided path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_abs_f32_strided_run`;
/// `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_abs_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_abs_f16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_abs_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `abs`, bf16 dtype, contiguous fast path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_abs_f32_run`; `x` / `y`
/// point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_abs_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_abs_bf16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_abs_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `abs`, bf16 dtype, strided path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_abs_f32_strided_run`;
/// `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_abs_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_abs_bf16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_abs_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `abs`, f64 dtype, contiguous fast path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_abs_f32_run`; `x` / `y`
/// point to `double` storage.
pub fn baracuda_kernels_unary_abs_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_abs_f64`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_abs_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `abs`, f64 dtype, strided path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_abs_f32_strided_run`;
/// `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_abs_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_abs_f64_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_abs_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `sign` — `y = sign(x) ∈ {-1, 0, +1}` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Launchers for unary `sign` (result in {-1, 0, +1}). Each dtype (f32,
// f16, bf16, f64) declares four symbols: a contig `_run`, its
// `_can_implement` check, a `_strided_run`, and its
// `_strided_can_implement` check. Every symbol returns the `i32` status
// code described in the crate-level docs.
/// Unary elementwise `sign`, f32 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer,
/// `baracuda_kernels_unary_neg_f32_run`.
pub fn baracuda_kernels_unary_sign_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sign_f32`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sign_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sign`, f32 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher,
/// `baracuda_kernels_unary_neg_f32_strided_run`.
pub fn baracuda_kernels_unary_sign_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sign_f32_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sign_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sign`, f16 dtype, contiguous fast path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_sign_f32_run`; `x` / `y`
/// point to `__half` storage.
pub fn baracuda_kernels_unary_sign_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sign_f16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sign_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sign`, f16 dtype, strided path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_sign_f32_strided_run`;
/// `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_sign_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sign_f16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sign_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sign`, bf16 dtype, contiguous fast path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_sign_f32_run`; `x` / `y`
/// point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_sign_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sign_bf16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sign_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sign`, bf16 dtype, strided path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_sign_f32_strided_run`;
/// `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_sign_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sign_bf16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sign_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sign`, f64 dtype, contiguous fast path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_sign_f32_run`; `x` / `y`
/// point to `double` storage.
pub fn baracuda_kernels_unary_sign_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sign_f64`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sign_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sign`, f64 dtype, strided path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_sign_f32_strided_run`;
/// `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_sign_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sign_f64_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sign_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `reciprocal` — `y = 1 / x` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Launchers for unary `reciprocal` (`y = 1 / x`). Each dtype (f32, f16,
// bf16, f64) declares four symbols: a contig `_run`, its `_can_implement`
// check, a `_strided_run`, and its `_strided_can_implement` check. Every
// symbol returns the `i32` status code described in the crate-level docs.
/// Unary elementwise `reciprocal` (`y = 1 / x`), f32 dtype, contiguous
/// fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer,
/// `baracuda_kernels_unary_neg_f32_run`.
pub fn baracuda_kernels_unary_reciprocal_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_reciprocal_f32`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_reciprocal_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `reciprocal`, f32 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher,
/// `baracuda_kernels_unary_neg_f32_strided_run`.
pub fn baracuda_kernels_unary_reciprocal_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_reciprocal_f32_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_reciprocal_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `reciprocal`, f16 dtype, contiguous fast path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_reciprocal_f32_run`;
/// `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_reciprocal_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_reciprocal_f16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_reciprocal_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `reciprocal`, f16 dtype, strided path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_reciprocal_f32_strided_run`;
/// `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_reciprocal_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_reciprocal_f16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_reciprocal_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `reciprocal`, bf16 dtype, contiguous fast path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_reciprocal_f32_run`;
/// `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_reciprocal_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_reciprocal_bf16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_reciprocal_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `reciprocal`, bf16 dtype, strided path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_reciprocal_f32_strided_run`;
/// `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_reciprocal_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_reciprocal_bf16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_reciprocal_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `reciprocal`, f64 dtype, contiguous fast path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_reciprocal_f32_run`;
/// `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_reciprocal_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_reciprocal_f64`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_reciprocal_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `reciprocal`, f64 dtype, strided path.
///
/// # Safety
/// Same contract as `baracuda_kernels_unary_reciprocal_f32_strided_run`;
/// `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_reciprocal_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_reciprocal_f64_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_reciprocal_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `square` — `y = x * x` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Unary elementwise `square`, f32 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_square_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_square_f32`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_square_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `square`, f32 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_square_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_square_f32_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_square_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `square`, f16 dtype, contiguous fast path.
///
/// # Safety
/// Same as the f32 variant; `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_square_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_square_f16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_square_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `square`, f16 dtype, strided path.
///
/// # Safety
/// Same as the f32 variant; `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_square_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_square_f16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_square_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `square`, bf16 dtype, contiguous fast path.
///
/// # Safety
/// Same as the f32 variant; `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_square_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_square_bf16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_square_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `square`, bf16 dtype, strided path.
///
/// # Safety
/// Same as the f32 variant; `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_square_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_square_bf16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_square_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `square`, f64 dtype, contiguous fast path.
///
/// # Safety
/// Same as the f32 variant; `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_square_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_square_f64`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_square_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `square`, f64 dtype, strided path.
///
/// # Safety
/// Same as the f32 variant; `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_square_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_square_f64_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_square_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `cube` — `y = x * x * x` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Entry points for the unary `cube` family. Each dtype (f32 / f16 / bf16 /
// f64) exposes a contiguous `_run`, a strided `_strided_run`, and one
// `_can_implement` pre-launch check for each of the two launchers.
//
// NOTE(review): parameter meanings are not spelled out in this file.
// `numel` is presumably the element count, and `shape` / `stride_x` /
// `stride_y` presumably hold `rank` entries each — confirm against the
// unary-neg launcher that this family mirrors.

/// Unary elementwise `cube` (`y = x * x * x`), f32 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// The crate-level requirements also apply: `x` / `y` must be valid device
/// addresses, `workspace` (when non-null) must point to at least
/// `workspace_bytes` of writable device memory, and `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_unary_cube_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cube_f32`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cube_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cube` (`y = x * x * x`), f32 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_cube_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cube_f32_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cube_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cube` (`y = x * x * x`), f16 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the f32 contiguous variant; `x` / `y` point to
/// `__half` storage.
pub fn baracuda_kernels_unary_cube_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cube_f16`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cube_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cube` (`y = x * x * x`), f16 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the f32 strided variant; `x` / `y` point to
/// `__half` storage.
pub fn baracuda_kernels_unary_cube_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cube_f16_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cube_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cube` (`y = x * x * x`), bf16 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the f32 contiguous variant; `x` / `y` point to
/// `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_cube_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cube_bf16`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cube_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cube` (`y = x * x * x`), bf16 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the f32 strided variant; `x` / `y` point to
/// `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_cube_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cube_bf16_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cube_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cube` (`y = x * x * x`), f64 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the f32 contiguous variant; `x` / `y` point to
/// `double` storage.
pub fn baracuda_kernels_unary_cube_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cube_f64`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cube_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cube` (`y = x * x * x`), f64 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the f32 strided variant; `x` / `y` point to
/// `double` storage.
pub fn baracuda_kernels_unary_cube_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cube_f64_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cube_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `sqrt` — `y = sqrt(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Entry points for the unary `sqrt` family. Each dtype (f32 / f16 / bf16 /
// f64) exposes a contiguous `_run`, a strided `_strided_run`, and one
// `_can_implement` pre-launch check for each of the two launchers.
//
// NOTE(review): parameter meanings are not spelled out in this file.
// `numel` is presumably the element count, and `shape` / `stride_x` /
// `stride_y` presumably hold `rank` entries each — confirm against the
// unary-neg launcher that this family mirrors.

/// Unary elementwise `sqrt` (`y = sqrt(x)`), f32 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// The crate-level requirements also apply: `x` / `y` must be valid device
/// addresses, `workspace` (when non-null) must point to at least
/// `workspace_bytes` of writable device memory, and `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_unary_sqrt_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sqrt_f32`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sqrt_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sqrt` (`y = sqrt(x)`), f32 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_sqrt_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sqrt_f32_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sqrt_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sqrt` (`y = sqrt(x)`), f16 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_sqrt_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sqrt_f16`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sqrt_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sqrt` (`y = sqrt(x)`), f16 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_sqrt_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sqrt_f16_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sqrt_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sqrt` (`y = sqrt(x)`), bf16 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_sqrt_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sqrt_bf16`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sqrt_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sqrt` (`y = sqrt(x)`), bf16 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_sqrt_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sqrt_bf16_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sqrt_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sqrt` (`y = sqrt(x)`), f64 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_sqrt_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sqrt_f64`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sqrt_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sqrt` (`y = sqrt(x)`), f64 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_sqrt_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sqrt_f64_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sqrt_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `rsqrt` — `y = 1 / sqrt(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Entry points for the unary `rsqrt` family. Each dtype (f32 / f16 / bf16 /
// f64) exposes a contiguous `_run`, a strided `_strided_run`, and one
// `_can_implement` pre-launch check for each of the two launchers.
//
// NOTE(review): parameter meanings are not spelled out in this file.
// `numel` is presumably the element count, and `shape` / `stride_x` /
// `stride_y` presumably hold `rank` entries each — confirm against the
// unary-neg launcher that this family mirrors.

/// Unary elementwise `rsqrt` (`y = 1 / sqrt(x)`), f32 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// The crate-level requirements also apply: `x` / `y` must be valid device
/// addresses, `workspace` (when non-null) must point to at least
/// `workspace_bytes` of writable device memory, and `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_unary_rsqrt_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_rsqrt_f32`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_rsqrt_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `rsqrt` (`y = 1 / sqrt(x)`), f32 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_rsqrt_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_rsqrt_f32_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_rsqrt_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `rsqrt` (`y = 1 / sqrt(x)`), f16 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_rsqrt_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_rsqrt_f16`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_rsqrt_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `rsqrt` (`y = 1 / sqrt(x)`), f16 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_rsqrt_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_rsqrt_f16_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_rsqrt_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `rsqrt` (`y = 1 / sqrt(x)`), bf16 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_rsqrt_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_rsqrt_bf16`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_rsqrt_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `rsqrt` (`y = 1 / sqrt(x)`), bf16 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_rsqrt_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_rsqrt_bf16_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_rsqrt_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `rsqrt` (`y = 1 / sqrt(x)`), f64 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_rsqrt_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_rsqrt_f64`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_rsqrt_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `rsqrt` (`y = 1 / sqrt(x)`), f64 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_rsqrt_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_rsqrt_f64_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_rsqrt_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `exp` — `y = exp(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Entry points for the unary `exp` family. Each dtype (f32 / f16 / bf16 /
// f64) exposes a contiguous `_run`, a strided `_strided_run`, and one
// `_can_implement` pre-launch check for each of the two launchers.
//
// NOTE(review): parameter meanings are not spelled out in this file.
// `numel` is presumably the element count, and `shape` / `stride_x` /
// `stride_y` presumably hold `rank` entries each — confirm against the
// unary-neg launcher that this family mirrors.

/// Unary elementwise `exp` (`y = exp(x)`), f32 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// The crate-level requirements also apply: `x` / `y` must be valid device
/// addresses, `workspace` (when non-null) must point to at least
/// `workspace_bytes` of writable device memory, and `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_unary_exp_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp_f32`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_exp_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `exp` (`y = exp(x)`), f32 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_exp_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp_f32_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_exp_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `exp` (`y = exp(x)`), f16 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_exp_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp_f16`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_exp_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `exp` (`y = exp(x)`), f16 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_exp_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp_f16_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_exp_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `exp` (`y = exp(x)`), bf16 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_exp_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp_bf16`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_exp_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `exp` (`y = exp(x)`), bf16 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_exp_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp_bf16_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_exp_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `exp` (`y = exp(x)`), f64 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_exp_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp_f64`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_exp_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `exp` (`y = exp(x)`), f64 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_exp_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp_f64_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_exp_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `expm1` — `y = exp(x) - 1` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Entry points for the unary `expm1` family. Each dtype (f32 / f16 / bf16 /
// f64) exposes a contiguous `_run`, a strided `_strided_run`, and one
// `_can_implement` pre-launch check for each of the two launchers.
//
// NOTE(review): parameter meanings are not spelled out in this file.
// `numel` is presumably the element count, and `shape` / `stride_x` /
// `stride_y` presumably hold `rank` entries each — confirm against the
// unary-neg launcher that this family mirrors.

/// Unary elementwise `expm1` (`y = exp(x) - 1`), f32 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// The crate-level requirements also apply: `x` / `y` must be valid device
/// addresses, `workspace` (when non-null) must point to at least
/// `workspace_bytes` of writable device memory, and `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_unary_expm1_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_expm1_f32`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_expm1_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `expm1` (`y = exp(x) - 1`), f32 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_expm1_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_expm1_f32_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_expm1_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `expm1` (`y = exp(x) - 1`), f16 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_expm1_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_expm1_f16`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_expm1_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `expm1` (`y = exp(x) - 1`), f16 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_expm1_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_expm1_f16_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_expm1_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `expm1` (`y = exp(x) - 1`), bf16 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_expm1_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_expm1_bf16`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_expm1_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `expm1` (`y = exp(x) - 1`), bf16 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_expm1_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_expm1_bf16_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_expm1_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `expm1` (`y = exp(x) - 1`), f64 dtype, contiguous fast path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_expm1_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_expm1_f64`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_expm1_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `expm1` (`y = exp(x) - 1`), f64 dtype, strided path.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_expm1_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_expm1_f64_strided`.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_expm1_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `log` — `y = ln(x)` (natural log) across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Raw FFI declarations for the unary elementwise `log` kernels. Each dtype
// (f32 / f16 / bf16 / f64) exposes four entry points:
//   `*_run`                   - contiguous fast-path launcher
//   `*_can_implement`         - pre-launch check paired with `*_run`
//   `*_strided_run`           - strided launcher
//   `*_strided_can_implement` - pre-launch check paired with `*_strided_run`
// Every function returns one of the `i32` status codes listed in the
// crate-level docs (`0` = success). The block is only compiled when at least
// one of the `sm80` / `sm89` / `sm90a` features is enabled.
// NOTE(review): `numel` is presumably the element count, and `shape` /
// `stride_x` / `stride_y` presumably hold `rank` entries each. Stride units
// and host-vs-device residency of those arrays are not visible from these
// declarations - confirm against the unary-neg launchers.
/// Unary elementwise `log`, f32 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_log_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log_f32`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log`, f32 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_log_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log_f32_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log`, f16 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_log_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log_f16`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log`, f16 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_log_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log_f16_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log`, bf16 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_log_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log_bf16`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log`, bf16 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_log_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log_bf16_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log`, f64 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_log_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log_f64`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log`, f64 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_log_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log_f64_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `log1p` — `y = ln(1 + x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Raw FFI declarations for the unary elementwise `log1p` kernels. Each dtype
// (f32 / f16 / bf16 / f64) exposes four entry points:
//   `*_run`                   - contiguous fast-path launcher
//   `*_can_implement`         - pre-launch check paired with `*_run`
//   `*_strided_run`           - strided launcher
//   `*_strided_can_implement` - pre-launch check paired with `*_strided_run`
// Every function returns one of the `i32` status codes listed in the
// crate-level docs (`0` = success). The block is only compiled when at least
// one of the `sm80` / `sm89` / `sm90a` features is enabled.
// NOTE(review): `numel` is presumably the element count, and `shape` /
// `stride_x` / `stride_y` presumably hold `rank` entries each. Stride units
// and host-vs-device residency of those arrays are not visible from these
// declarations - confirm against the unary-neg launchers.
/// Unary elementwise `log1p`, f32 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_log1p_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log1p_f32`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log1p_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log1p`, f32 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_log1p_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log1p_f32_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log1p_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log1p`, f16 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_log1p_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log1p_f16`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log1p_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log1p`, f16 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_log1p_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log1p_f16_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log1p_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log1p`, bf16 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_log1p_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log1p_bf16`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log1p_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log1p`, bf16 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_log1p_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log1p_bf16_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log1p_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log1p`, f64 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_log1p_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log1p_f64`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log1p_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log1p`, f64 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_log1p_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log1p_f64_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log1p_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `sin` — `y = sin(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Raw FFI declarations for the unary elementwise `sin` kernels. Each dtype
// (f32 / f16 / bf16 / f64) exposes four entry points:
//   `*_run`                   - contiguous fast-path launcher
//   `*_can_implement`         - pre-launch check paired with `*_run`
//   `*_strided_run`           - strided launcher
//   `*_strided_can_implement` - pre-launch check paired with `*_strided_run`
// Every function returns one of the `i32` status codes listed in the
// crate-level docs (`0` = success). The block is only compiled when at least
// one of the `sm80` / `sm89` / `sm90a` features is enabled.
// NOTE(review): `numel` is presumably the element count, and `shape` /
// `stride_x` / `stride_y` presumably hold `rank` entries each. Stride units
// and host-vs-device residency of those arrays are not visible from these
// declarations - confirm against the unary-neg launchers.
/// Unary elementwise `sin`, f32 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_sin_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sin_f32`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sin_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sin`, f32 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_sin_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sin_f32_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sin_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sin`, f16 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_sin_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sin_f16`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sin_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sin`, f16 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_sin_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sin_f16_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sin_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sin`, bf16 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_sin_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sin_bf16`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sin_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sin`, bf16 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_sin_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sin_bf16_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sin_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sin`, f64 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_sin_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sin_f64`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sin_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sin`, f64 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_sin_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sin_f64_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sin_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `cos` — `y = cos(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Raw FFI declarations for the unary elementwise `cos` kernels. Each dtype
// (f32 / f16 / bf16 / f64) exposes four entry points:
//   `*_run`                   - contiguous fast-path launcher
//   `*_can_implement`         - pre-launch check paired with `*_run`
//   `*_strided_run`           - strided launcher
//   `*_strided_can_implement` - pre-launch check paired with `*_strided_run`
// Every function returns one of the `i32` status codes listed in the
// crate-level docs (`0` = success). The block is only compiled when at least
// one of the `sm80` / `sm89` / `sm90a` features is enabled.
// NOTE(review): `numel` is presumably the element count, and `shape` /
// `stride_x` / `stride_y` presumably hold `rank` entries each. Stride units
// and host-vs-device residency of those arrays are not visible from these
// declarations - confirm against the unary-neg launchers.
/// Unary elementwise `cos`, f32 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_cos_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cos_f32`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cos_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cos`, f32 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_cos_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cos_f32_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cos_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cos`, f16 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_cos_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cos_f16`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cos_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cos`, f16 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_cos_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cos_f16_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cos_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cos`, bf16 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_cos_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cos_bf16`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cos_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cos`, bf16 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_cos_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cos_bf16_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cos_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cos`, f64 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_cos_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cos_f64`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cos_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cos`, f64 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_cos_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cos_f64_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cos_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `tan` — `y = tan(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Raw FFI declarations for the unary elementwise `tan` kernels. Each dtype
// (f32 / f16 / bf16 / f64) exposes four entry points:
//   `*_run`                   - contiguous fast-path launcher
//   `*_can_implement`         - pre-launch check paired with `*_run`
//   `*_strided_run`           - strided launcher
//   `*_strided_can_implement` - pre-launch check paired with `*_strided_run`
// Every function returns one of the `i32` status codes listed in the
// crate-level docs (`0` = success). The block is only compiled when at least
// one of the `sm80` / `sm89` / `sm90a` features is enabled.
// NOTE(review): `numel` is presumably the element count, and `shape` /
// `stride_x` / `stride_y` presumably hold `rank` entries each. Stride units
// and host-vs-device residency of those arrays are not visible from these
// declarations - confirm against the unary-neg launchers.
/// Unary elementwise `tan`, f32 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_tan_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tan_f32`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tan_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tan`, f32 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_tan_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tan_f32_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tan_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tan`, f16 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_tan_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tan_f16`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tan_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tan`, f16 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_tan_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tan_f16_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tan_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tan`, bf16 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_tan_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tan_bf16`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tan_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tan`, bf16 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_tan_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tan_bf16_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tan_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tan`, f64 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_tan_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tan_f64`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tan_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tan`, f64 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_tan_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tan_f64_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tan_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `sinh` — `y = sinh(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Raw FFI declarations for the unary elementwise `sinh` kernels. Each dtype
// (f32 / f16 / bf16 / f64) exposes four entry points:
//   `*_run`                   - contiguous fast-path launcher
//   `*_can_implement`         - pre-launch check paired with `*_run`
//   `*_strided_run`           - strided launcher
//   `*_strided_can_implement` - pre-launch check paired with `*_strided_run`
// Every function returns one of the `i32` status codes listed in the
// crate-level docs (`0` = success). The block is only compiled when at least
// one of the `sm80` / `sm89` / `sm90a` features is enabled.
// NOTE(review): `numel` is presumably the element count, and `shape` /
// `stride_x` / `stride_y` presumably hold `rank` entries each. Stride units
// and host-vs-device residency of those arrays are not visible from these
// declarations - confirm against the unary-neg launchers.
/// Unary elementwise `sinh`, f32 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_sinh_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sinh_f32`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sinh_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sinh`, f32 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_sinh_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sinh_f32_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sinh_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sinh`, f16 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_sinh_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sinh_f16`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sinh_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sinh`, f16 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_sinh_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sinh_f16_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sinh_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sinh`, bf16 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_sinh_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sinh_bf16`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sinh_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sinh`, bf16 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_sinh_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sinh_bf16_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sinh_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sinh`, f64 dtype, contiguous fast path.
///
/// Reads through `x` and writes through `y`; `workspace` /
/// `workspace_bytes` and `stream` follow the crate-level contract.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_sinh_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sinh_f64`.
///
/// Mirrors the `numel` / `x` / `y` arguments of the `_run` entry point (no
/// workspace or stream). Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sinh_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sinh`, f64 dtype, strided path.
///
/// `rank`, `shape`, `stride_x`, and `stride_y` presumably describe the
/// operand layout (TODO confirm entry counts and units). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_sinh_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sinh_f64_strided`.
///
/// Mirrors the layout and pointer arguments of the `_strided_run` entry
/// point (no workspace or stream). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sinh_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `cosh` — `y = cosh(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Raw launcher declarations for elementwise `cosh`. Each dtype (f32, f16,
// bf16, f64) exposes four symbols: a contiguous `_run`, its
// `_can_implement` check, a `_strided_run`, and its
// `_strided_can_implement` check.
//
// Shared parameters:
// - `x` / `y`: input / output pointers, passed untyped as `c_void`.
// - `workspace` / `workspace_bytes`: scratch pointer and its size in bytes.
// - `stream`: a `cudaStream_t` cast to `*mut c_void` (per the crate docs).
// - `numel`: presumably the total element count — TODO confirm.
// - `rank` / `shape` / `stride_x` / `stride_y` (strided variants only):
//   presumably arrays of `rank` entries; whether strides are counted in
//   elements or bytes is not visible here — TODO confirm against the
//   launcher. NOTE(review): the `_can_implement` docs say "host-side
//   checks only", so these arrays are presumably host pointers — confirm.
//
// Every function returns the crate-level `i32` status code (`0` = success).
/// Unary elementwise `cosh`, f32 dtype, contiguous fast path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_cosh_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cosh_f32`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cosh_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cosh`, f32 dtype, strided path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_cosh_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cosh_f32_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cosh_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cosh`, f16 dtype, contiguous fast path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_cosh_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cosh_f16`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cosh_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cosh`, f16 dtype, strided path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_cosh_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cosh_f16_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cosh_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cosh`, bf16 dtype, contiguous fast path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_cosh_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cosh_bf16`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cosh_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cosh`, bf16 dtype, strided path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_cosh_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cosh_bf16_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cosh_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cosh`, f64 dtype, contiguous fast path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_cosh_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cosh_f64`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cosh_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cosh`, f64 dtype, strided path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_cosh_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cosh_f64_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cosh_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `tanh` — `y = tanh(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Raw launcher declarations for elementwise `tanh`. Each dtype (f32, f16,
// bf16, f64) exposes four symbols: a contiguous `_run`, its
// `_can_implement` check, a `_strided_run`, and its
// `_strided_can_implement` check.
//
// Shared parameters:
// - `x` / `y`: input / output pointers, passed untyped as `c_void`.
// - `workspace` / `workspace_bytes`: scratch pointer and its size in bytes.
// - `stream`: a `cudaStream_t` cast to `*mut c_void` (per the crate docs).
// - `numel`: presumably the total element count — TODO confirm.
// - `rank` / `shape` / `stride_x` / `stride_y` (strided variants only):
//   presumably arrays of `rank` entries; whether strides are counted in
//   elements or bytes is not visible here — TODO confirm against the
//   launcher. NOTE(review): the `_can_implement` docs say "host-side
//   checks only", so these arrays are presumably host pointers — confirm.
//
// Every function returns the crate-level `i32` status code (`0` = success).
/// Unary elementwise `tanh`, f32 dtype, contiguous fast path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_tanh_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanh_f32`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tanh_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tanh`, f32 dtype, strided path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_tanh_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanh_f32_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tanh_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tanh`, f16 dtype, contiguous fast path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_tanh_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanh_f16`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tanh_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tanh`, f16 dtype, strided path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_tanh_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanh_f16_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tanh_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tanh`, bf16 dtype, contiguous fast path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_tanh_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanh_bf16`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tanh_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tanh`, bf16 dtype, strided path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_tanh_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanh_bf16_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tanh_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tanh`, f64 dtype, contiguous fast path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_tanh_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanh_f64`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tanh_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tanh`, f64 dtype, strided path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_tanh_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanh_f64_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tanh_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `relu` — `y = relu(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Raw launcher declarations for elementwise `relu`. Each dtype (f32, f16,
// bf16, f64) exposes four symbols: a contiguous `_run`, its
// `_can_implement` check, a `_strided_run`, and its
// `_strided_can_implement` check.
//
// Shared parameters:
// - `x` / `y`: input / output pointers, passed untyped as `c_void`.
// - `workspace` / `workspace_bytes`: scratch pointer and its size in bytes.
// - `stream`: a `cudaStream_t` cast to `*mut c_void` (per the crate docs).
// - `numel`: presumably the total element count — TODO confirm.
// - `rank` / `shape` / `stride_x` / `stride_y` (strided variants only):
//   presumably arrays of `rank` entries; whether strides are counted in
//   elements or bytes is not visible here — TODO confirm against the
//   launcher. NOTE(review): the `_can_implement` docs say "host-side
//   checks only", so these arrays are presumably host pointers — confirm.
//
// Every function returns the crate-level `i32` status code (`0` = success).
/// Unary elementwise `relu`, f32 dtype, contiguous fast path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_relu_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu_f32`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_relu_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `relu`, f32 dtype, strided path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_relu_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu_f32_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_relu_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `relu`, f16 dtype, contiguous fast path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_relu_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu_f16`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_relu_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `relu`, f16 dtype, strided path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_relu_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu_f16_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_relu_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `relu`, bf16 dtype, contiguous fast path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_relu_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu_bf16`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_relu_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `relu`, bf16 dtype, strided path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_relu_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu_bf16_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_relu_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `relu`, f64 dtype, contiguous fast path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_relu_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu_f64`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_relu_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `relu`, f64 dtype, strided path.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_relu_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu_f64_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_relu_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `gelu` — `y = gelu(x)` across f32 / f16 / bf16 / f64.
//
// ERF-EXACT gelu (`0.5·x·(1+erf(x/√2))`) — NOT the tanh approximation.
// `unary_gelu_erf_*` is a bit-identical alias; use `unary_gelu_tanh_*`
// for the tanh flavor.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Raw launcher declarations for elementwise erf-exact `gelu`. Each dtype
// (f32, f16, bf16, f64) exposes four symbols: a contiguous `_run`, its
// `_can_implement` check, a `_strided_run`, and its
// `_strided_can_implement` check.
//
// Shared parameters:
// - `x` / `y`: input / output pointers, passed untyped as `c_void`.
// - `workspace` / `workspace_bytes`: scratch pointer and its size in bytes.
// - `stream`: a `cudaStream_t` cast to `*mut c_void` (per the crate docs).
// - `numel`: presumably the total element count — TODO confirm.
// - `rank` / `shape` / `stride_x` / `stride_y` (strided variants only):
//   presumably arrays of `rank` entries; whether strides are counted in
//   elements or bytes is not visible here — TODO confirm against the
//   launcher. NOTE(review): the `_can_implement` docs say "host-side
//   checks only", so these arrays are presumably host pointers — confirm.
//
// Every function returns the crate-level `i32` status code (`0` = success).
/// Unary elementwise `gelu`, f32 dtype, contiguous fast path.
///
/// ERF-EXACT gelu (`0.5·x·(1+erf(x/√2))`) — NOT the tanh
/// approximation. `unary_gelu_erf_*` is a bit-identical alias; use
/// `unary_gelu_tanh_*` for the tanh flavor.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_gelu_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_f32`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `gelu`, f32 dtype, strided path.
///
/// ERF-EXACT gelu (`0.5·x·(1+erf(x/√2))`) — NOT the tanh
/// approximation; use `unary_gelu_tanh_*` for the tanh flavor.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_gelu_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_f32_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `gelu`, f16 dtype, contiguous fast path.
///
/// ERF-EXACT gelu (`0.5·x·(1+erf(x/√2))`) — NOT the tanh
/// approximation. `unary_gelu_erf_*` is a bit-identical alias; use
/// `unary_gelu_tanh_*` for the tanh flavor.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_gelu_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_f16`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `gelu`, f16 dtype, strided path.
///
/// ERF-EXACT gelu (`0.5·x·(1+erf(x/√2))`) — NOT the tanh
/// approximation; use `unary_gelu_tanh_*` for the tanh flavor.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_gelu_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_f16_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `gelu`, bf16 dtype, contiguous fast path.
///
/// ERF-EXACT gelu (`0.5·x·(1+erf(x/√2))`) — NOT the tanh
/// approximation. `unary_gelu_erf_*` is a bit-identical alias; use
/// `unary_gelu_tanh_*` for the tanh flavor.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_gelu_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_bf16`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `gelu`, bf16 dtype, strided path.
///
/// ERF-EXACT gelu (`0.5·x·(1+erf(x/√2))`) — NOT the tanh
/// approximation; use `unary_gelu_tanh_*` for the tanh flavor.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_gelu_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_bf16_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `gelu`, f64 dtype, contiguous fast path.
///
/// ERF-EXACT gelu (`0.5·x·(1+erf(x/√2))`) — NOT the tanh
/// approximation. `unary_gelu_erf_*` is a bit-identical alias; use
/// `unary_gelu_tanh_*` for the tanh flavor.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_gelu_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_f64`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `gelu`, f64 dtype, strided path.
///
/// ERF-EXACT gelu (`0.5·x·(1+erf(x/√2))`) — NOT the tanh
/// approximation; use `unary_gelu_tanh_*` for the tanh flavor.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_gelu_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_f64_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `gelu_tanh` — `y = gelu_tanh(x)` across f32 / f16 / bf16 / f64.
//
// Tanh APPROXIMATION of gelu (`0.5·x·(1+tanh(√(2/π)·(x+0.044715·x³)))`)
// — diverges from the erf-exact `unary_gelu_*` by up to ~1e-4.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Raw launcher declarations for elementwise `gelu_tanh` (the tanh
// approximation of gelu). Each dtype (f32, f16, bf16, f64) exposes four
// symbols: a contiguous `_run`, its `_can_implement` check, a
// `_strided_run`, and its `_strided_can_implement` check.
//
// Shared parameters:
// - `x` / `y`: input / output pointers, passed untyped as `c_void`.
// - `workspace` / `workspace_bytes`: scratch pointer and its size in bytes.
// - `stream`: a `cudaStream_t` cast to `*mut c_void` (per the crate docs).
// - `numel`: presumably the total element count — TODO confirm.
// - `rank` / `shape` / `stride_x` / `stride_y` (strided variants only):
//   presumably arrays of `rank` entries; whether strides are counted in
//   elements or bytes is not visible here — TODO confirm against the
//   launcher. NOTE(review): the `_can_implement` docs say "host-side
//   checks only", so these arrays are presumably host pointers — confirm.
//
// Every function returns the crate-level `i32` status code (`0` = success).
/// Unary elementwise `gelu_tanh`, f32 dtype, contiguous fast path.
///
/// Tanh APPROXIMATION of gelu — diverges from the erf-exact
/// `unary_gelu_*` by up to ~1e-4.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_gelu_tanh_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_tanh_f32`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_tanh_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `gelu_tanh`, f32 dtype, strided path.
///
/// Tanh APPROXIMATION of gelu — diverges from the erf-exact
/// `unary_gelu_*` by up to ~1e-4.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_gelu_tanh_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_tanh_f32_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_tanh_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `gelu_tanh`, f16 dtype, contiguous fast path.
///
/// Tanh APPROXIMATION of gelu — diverges from the erf-exact
/// `unary_gelu_*` by up to ~1e-4.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_gelu_tanh_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_tanh_f16`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_tanh_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `gelu_tanh`, f16 dtype, strided path.
///
/// Tanh APPROXIMATION of gelu — diverges from the erf-exact
/// `unary_gelu_*` by up to ~1e-4.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_gelu_tanh_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_tanh_f16_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_tanh_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `gelu_tanh`, bf16 dtype, contiguous fast path.
///
/// Tanh APPROXIMATION of gelu — diverges from the erf-exact
/// `unary_gelu_*` by up to ~1e-4.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_gelu_tanh_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_tanh_bf16`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_tanh_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `gelu_tanh`, bf16 dtype, strided path.
///
/// Tanh APPROXIMATION of gelu — diverges from the erf-exact
/// `unary_gelu_*` by up to ~1e-4.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_gelu_tanh_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_tanh_bf16_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_tanh_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `gelu_tanh`, f64 dtype, contiguous fast path.
///
/// Tanh APPROXIMATION of gelu — diverges from the erf-exact
/// `unary_gelu_*` by up to ~1e-4.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_gelu_tanh_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_tanh_f64`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_tanh_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `gelu_tanh`, f64 dtype, strided path.
///
/// Tanh APPROXIMATION of gelu — diverges from the erf-exact
/// `unary_gelu_*` by up to ~1e-4.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_gelu_tanh_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_tanh_f64_strided`.
///
/// Returns the crate-level `i32` status code; `0` means success.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_tanh_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `silu` — `y = silu(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Unary elementwise `silu`, f32 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_silu_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_silu_f32`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_silu_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `silu`, f32 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_silu_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_silu_f32_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_silu_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `silu`, f16 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_silu_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_silu_f16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_silu_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `silu`, f16 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_silu_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_silu_f16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_silu_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `silu`, bf16 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_silu_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_silu_bf16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_silu_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `silu`, bf16 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_silu_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_silu_bf16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_silu_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `silu`, f64 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_silu_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_silu_f64`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_silu_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `silu`, f64 dtype, strided path.
///
/// On top of the contiguous arguments, the strided path takes `rank`,
/// `shape`, `stride_x` and `stride_y`. Input is read through `x` and the
/// result is written through `y`. The returned `i32` is a status code from
/// the crate-level table (`0` is success).
///
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each; stride units and whether these arrays live in host or
/// device memory are not stated here — confirm against the unary-neg
/// strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_silu_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_silu_f64_strided`.
///
/// Takes the `numel` / `rank` / `shape` / `stride_*` / `x` / `y` arguments of
/// the matching `_strided_run` entry point, without workspace or stream;
/// `y` is `*const` here. The returned `i32` is a status code from the
/// crate-level table (`0` is success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_silu_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `mish` — `y = mish(x)` across f32 / f16 / bf16 / f64.
//
// Four entry points per dtype:
//   * `_run`                   — contiguous launch.
//   * `_can_implement`         — pre-launch check for `_run`.
//   * `_strided_run`           — strided launch (`rank` / `shape` / `stride_*`).
//   * `_strided_can_implement` — pre-launch check for `_strided_run`.
//
// Every function returns the crate-level `i32` status code (`0` = success).
// `x` is the read-only input and `y` the output (`*mut` on the launchers,
// `*const` on the checks). `workspace` / `workspace_bytes` / `stream` follow
// the crate-level contract: the workspace, when non-null, spans at least
// `workspace_bytes` of writable device memory, and `stream` is a
// `cudaStream_t` passed as `*mut c_void`.
//
// NOTE(review): conventionally mish(x) = x * tanh(softplus(x)); the formula
// itself lives in the kernel source and is not visible from here.
// NOTE(review): the f32 entries do not name their storage type (presumably
// `float`); the other dtypes state theirs explicitly.
// NOTE(review): `numel` is presumably the element count, and `shape` /
// `stride_x` / `stride_y` presumably hold `rank` entries each; stride units
// and host-vs-device residency of those arrays are not stated here — confirm
// against the unary-neg strided launcher.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// -- f32 ---------------------------------------------------------------------
/// Unary elementwise `mish`, f32 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_mish_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_mish_f32`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_mish_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `mish`, f32 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_mish_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_mish_f32_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_mish_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// -- f16 ---------------------------------------------------------------------
/// Unary elementwise `mish`, f16 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_mish_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_mish_f16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_mish_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `mish`, f16 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_mish_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_mish_f16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_mish_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// -- bf16 --------------------------------------------------------------------
/// Unary elementwise `mish`, bf16 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_mish_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_mish_bf16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_mish_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `mish`, bf16 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_mish_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_mish_bf16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_mish_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// -- f64 ---------------------------------------------------------------------
/// Unary elementwise `mish`, f64 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_mish_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_mish_f64`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_mish_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `mish`, f64 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_mish_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_mish_f64_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_mish_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `sigmoid` — `y = sigmoid(x)` across f32 / f16 / bf16 / f64.
//
// Four entry points per dtype:
//   * `_run`                   — contiguous launch.
//   * `_can_implement`         — pre-launch check for `_run`.
//   * `_strided_run`           — strided launch (`rank` / `shape` / `stride_*`).
//   * `_strided_can_implement` — pre-launch check for `_strided_run`.
//
// Every function returns the crate-level `i32` status code (`0` = success).
// `x` is the read-only input and `y` the output (`*mut` on the launchers,
// `*const` on the checks). `workspace` / `workspace_bytes` / `stream` follow
// the crate-level contract: the workspace, when non-null, spans at least
// `workspace_bytes` of writable device memory, and `stream` is a
// `cudaStream_t` passed as `*mut c_void`.
//
// NOTE(review): conventionally sigmoid(x) = 1 / (1 + exp(-x)); the formula
// itself lives in the kernel source and is not visible from here.
// NOTE(review): the f32 entries do not name their storage type (presumably
// `float`); the other dtypes state theirs explicitly.
// NOTE(review): `numel` is presumably the element count, and `shape` /
// `stride_x` / `stride_y` presumably hold `rank` entries each; stride units
// and host-vs-device residency of those arrays are not stated here — confirm
// against the unary-neg strided launcher.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// -- f32 ---------------------------------------------------------------------
/// Unary elementwise `sigmoid`, f32 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_sigmoid_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sigmoid_f32`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sigmoid_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sigmoid`, f32 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_sigmoid_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sigmoid_f32_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sigmoid_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// -- f16 ---------------------------------------------------------------------
/// Unary elementwise `sigmoid`, f16 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_sigmoid_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sigmoid_f16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sigmoid_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sigmoid`, f16 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_sigmoid_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sigmoid_f16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sigmoid_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// -- bf16 --------------------------------------------------------------------
/// Unary elementwise `sigmoid`, bf16 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_sigmoid_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sigmoid_bf16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sigmoid_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sigmoid`, bf16 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_sigmoid_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sigmoid_bf16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sigmoid_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// -- f64 ---------------------------------------------------------------------
/// Unary elementwise `sigmoid`, f64 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_sigmoid_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sigmoid_f64`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sigmoid_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `sigmoid`, f64 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_sigmoid_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_sigmoid_f64_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_sigmoid_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `softplus` — `y = softplus(x)` across f32 / f16 / bf16 / f64.
//
// Four entry points per dtype:
//   * `_run`                   — contiguous launch.
//   * `_can_implement`         — pre-launch check for `_run`.
//   * `_strided_run`           — strided launch (`rank` / `shape` / `stride_*`).
//   * `_strided_can_implement` — pre-launch check for `_strided_run`.
//
// Every function returns the crate-level `i32` status code (`0` = success).
// `x` is the read-only input and `y` the output (`*mut` on the launchers,
// `*const` on the checks). `workspace` / `workspace_bytes` / `stream` follow
// the crate-level contract: the workspace, when non-null, spans at least
// `workspace_bytes` of writable device memory, and `stream` is a
// `cudaStream_t` passed as `*mut c_void`.
//
// NOTE(review): conventionally softplus(x) = ln(1 + exp(x)). The signatures
// carry no scalar parameters, so any `beta` / threshold style constants are
// fixed inside the kernel — not visible from here.
// NOTE(review): the f32 entries do not name their storage type (presumably
// `float`); the other dtypes state theirs explicitly.
// NOTE(review): `numel` is presumably the element count, and `shape` /
// `stride_x` / `stride_y` presumably hold `rank` entries each; stride units
// and host-vs-device residency of those arrays are not stated here — confirm
// against the unary-neg strided launcher.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// -- f32 ---------------------------------------------------------------------
/// Unary elementwise `softplus`, f32 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_softplus_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softplus_f32`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softplus_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `softplus`, f32 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_softplus_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softplus_f32_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softplus_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// -- f16 ---------------------------------------------------------------------
/// Unary elementwise `softplus`, f16 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_softplus_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softplus_f16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softplus_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `softplus`, f16 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_softplus_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softplus_f16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softplus_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// -- bf16 --------------------------------------------------------------------
/// Unary elementwise `softplus`, bf16 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_softplus_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softplus_bf16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softplus_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `softplus`, bf16 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_softplus_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softplus_bf16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softplus_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// -- f64 ---------------------------------------------------------------------
/// Unary elementwise `softplus`, f64 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_softplus_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softplus_f64`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softplus_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `softplus`, f64 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_softplus_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softplus_f64_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softplus_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `hardswish` — `y = hardswish(x)` across f32 / f16 / bf16 / f64.
//
// Four entry points per dtype:
//   * `_run`                   — contiguous launch.
//   * `_can_implement`         — pre-launch check for `_run`.
//   * `_strided_run`           — strided launch (`rank` / `shape` / `stride_*`).
//   * `_strided_can_implement` — pre-launch check for `_strided_run`.
//
// Every function returns the crate-level `i32` status code (`0` = success).
// `x` is the read-only input and `y` the output (`*mut` on the launchers,
// `*const` on the checks). `workspace` / `workspace_bytes` / `stream` follow
// the crate-level contract: the workspace, when non-null, spans at least
// `workspace_bytes` of writable device memory, and `stream` is a
// `cudaStream_t` passed as `*mut c_void`.
//
// NOTE(review): conventionally hardswish(x) = x * clamp(x + 3, 0, 6) / 6. The
// signatures carry no scalar parameters, so whatever constants the kernel
// uses are fixed inside it — not visible from here; confirm.
// NOTE(review): the f32 entries do not name their storage type (presumably
// `float`); the other dtypes state theirs explicitly.
// NOTE(review): `numel` is presumably the element count, and `shape` /
// `stride_x` / `stride_y` presumably hold `rank` entries each; stride units
// and host-vs-device residency of those arrays are not stated here — confirm
// against the unary-neg strided launcher.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// -- f32 ---------------------------------------------------------------------
/// Unary elementwise `hardswish`, f32 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_hardswish_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardswish_f32`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardswish_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardswish`, f32 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_hardswish_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardswish_f32_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardswish_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// -- f16 ---------------------------------------------------------------------
/// Unary elementwise `hardswish`, f16 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_hardswish_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardswish_f16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardswish_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardswish`, f16 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_hardswish_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardswish_f16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardswish_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// -- bf16 --------------------------------------------------------------------
/// Unary elementwise `hardswish`, bf16 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_hardswish_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardswish_bf16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardswish_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardswish`, bf16 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_hardswish_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardswish_bf16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardswish_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// -- f64 ---------------------------------------------------------------------
/// Unary elementwise `hardswish`, f64 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_hardswish_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardswish_f64`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardswish_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardswish`, f64 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_hardswish_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardswish_f64_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardswish_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `hardsigmoid` — `y = hardsigmoid(x)` across f32 / f16 / bf16 / f64.
//
// Four entry points per dtype:
//   * `_run`                   — contiguous launch.
//   * `_can_implement`         — pre-launch check for `_run`.
//   * `_strided_run`           — strided launch (`rank` / `shape` / `stride_*`).
//   * `_strided_can_implement` — pre-launch check for `_strided_run`.
//
// Every function returns the crate-level `i32` status code (`0` = success).
// `x` is the read-only input and `y` the output (`*mut` on the launchers,
// `*const` on the checks). `workspace` / `workspace_bytes` / `stream` follow
// the crate-level contract: the workspace, when non-null, spans at least
// `workspace_bytes` of writable device memory, and `stream` is a
// `cudaStream_t` passed as `*mut c_void`.
//
// NOTE(review): more than one hardsigmoid convention exists (for example
// clamp(x + 3, 0, 6) / 6 versus clamp(0.2 * x + 0.5, 0, 1)). The signatures
// carry no scalar parameters, so the choice is fixed inside the kernel — not
// visible from here; confirm which one is implemented.
// NOTE(review): the f32 entries do not name their storage type (presumably
// `float`); the other dtypes state theirs explicitly.
// NOTE(review): `numel` is presumably the element count, and `shape` /
// `stride_x` / `stride_y` presumably hold `rank` entries each; stride units
// and host-vs-device residency of those arrays are not stated here — confirm
// against the unary-neg strided launcher.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// -- f32 ---------------------------------------------------------------------
/// Unary elementwise `hardsigmoid`, f32 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_hardsigmoid_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardsigmoid_f32`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardsigmoid_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardsigmoid`, f32 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_hardsigmoid_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardsigmoid_f32_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardsigmoid_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// -- f16 ---------------------------------------------------------------------
/// Unary elementwise `hardsigmoid`, f16 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_hardsigmoid_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardsigmoid_f16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardsigmoid_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardsigmoid`, f16 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_hardsigmoid_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardsigmoid_f16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardsigmoid_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// -- bf16 --------------------------------------------------------------------
/// Unary elementwise `hardsigmoid`, bf16 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_hardsigmoid_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardsigmoid_bf16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardsigmoid_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardsigmoid`, bf16 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_hardsigmoid_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardsigmoid_bf16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardsigmoid_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// -- f64 ---------------------------------------------------------------------
/// Unary elementwise `hardsigmoid`, f64 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_hardsigmoid_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardsigmoid_f64`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardsigmoid_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardsigmoid`, f64 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_hardsigmoid_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardsigmoid_f64_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardsigmoid_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `hardtanh` — `y = hardtanh(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Unary elementwise `hardtanh`, f32 dtype, contiguous fast path.
///
/// Input is read through `x` and the result is written through `y`. The
/// returned `i32` is a status code from the crate-level table (`0` is
/// success).
///
/// NOTE(review): no clamp bounds are passed in, so they are fixed inside the
/// kernel (not visible here). `numel` is presumably the element count, and
/// the storage type behind `x` / `y` is not named (presumably `float`).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// Per the crate-level contract, `workspace` (when non-null) must cover at
/// least `workspace_bytes` of writable device memory, and `stream` is a
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_unary_hardtanh_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardtanh_f32`.
///
/// Takes the `numel` / `x` / `y` arguments of the matching `_run` entry
/// point, without workspace or stream; `y` is `*const` here. The returned
/// `i32` is a status code from the crate-level table (`0` is success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardtanh_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardtanh`, f32 dtype, strided path.
///
/// On top of the contiguous arguments, the strided path takes `rank`,
/// `shape`, `stride_x` and `stride_y`. Input is read through `x` and the
/// result is written through `y`. The returned `i32` is a status code from
/// the crate-level table (`0` is success).
///
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each; stride units and whether these arrays live in host or
/// device memory are not stated here — confirm against the unary-neg
/// strided launcher. No clamp bounds are passed in, so they are fixed
/// inside the kernel.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_hardtanh_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardtanh_f32_strided`.
///
/// Takes the `numel` / `rank` / `shape` / `stride_*` / `x` / `y` arguments of
/// the matching `_strided_run` entry point, without workspace or stream;
/// `y` is `*const` here. The returned `i32` is a status code from the
/// crate-level table (`0` is success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardtanh_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardtanh`, f16 dtype, contiguous fast path.
///
/// Input is read through `x` and the result is written through `y`. The
/// returned `i32` is a status code from the crate-level table (`0` is
/// success).
///
/// NOTE(review): no clamp bounds are passed in, so they are fixed inside the
/// kernel (not visible here). `numel` is presumably the element count.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
/// Per the crate-level contract, `workspace` (when non-null) must cover at
/// least `workspace_bytes` of writable device memory, and `stream` is a
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_unary_hardtanh_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardtanh_f16`.
///
/// Takes the `numel` / `x` / `y` arguments of the matching `_run` entry
/// point, without workspace or stream; `y` is `*const` here. The returned
/// `i32` is a status code from the crate-level table (`0` is success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardtanh_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardtanh`, f16 dtype, strided path.
///
/// On top of the contiguous arguments, the strided path takes `rank`,
/// `shape`, `stride_x` and `stride_y`. Input is read through `x` and the
/// result is written through `y`. The returned `i32` is a status code from
/// the crate-level table (`0` is success).
///
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each; stride units and whether these arrays live in host or
/// device memory are not stated here — confirm against the unary-neg
/// strided launcher. No clamp bounds are passed in, so they are fixed
/// inside the kernel.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_hardtanh_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardtanh_f16_strided`.
///
/// Takes the `numel` / `rank` / `shape` / `stride_*` / `x` / `y` arguments of
/// the matching `_strided_run` entry point, without workspace or stream;
/// `y` is `*const` here. The returned `i32` is a status code from the
/// crate-level table (`0` is success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardtanh_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardtanh`, bf16 dtype, contiguous fast path.
///
/// Input is read through `x` and the result is written through `y`. The
/// returned `i32` is a status code from the crate-level table (`0` is
/// success).
///
/// NOTE(review): no clamp bounds are passed in, so they are fixed inside the
/// kernel (not visible here). `numel` is presumably the element count.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
/// Per the crate-level contract, `workspace` (when non-null) must cover at
/// least `workspace_bytes` of writable device memory, and `stream` is a
/// `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_unary_hardtanh_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardtanh_bf16`.
///
/// Takes the `numel` / `x` / `y` arguments of the matching `_run` entry
/// point, without workspace or stream; `y` is `*const` here. The returned
/// `i32` is a status code from the crate-level table (`0` is success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardtanh_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardtanh`, bf16 dtype, strided path.
///
/// On top of the contiguous arguments, the strided path takes `rank`,
/// `shape`, `stride_x` and `stride_y`. Input is read through `x` and the
/// result is written through `y`. The returned `i32` is a status code from
/// the crate-level table (`0` is success).
///
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each; stride units and whether these arrays live in host or
/// device memory are not stated here — confirm against the unary-neg
/// strided launcher. No clamp bounds are passed in, so they are fixed
/// inside the kernel.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_hardtanh_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardtanh_bf16_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardtanh_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardtanh`, f64 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
/// NOTE(review): no clamp bounds are passed, so the `hardtanh` limits are
/// presumably fixed inside the kernel — confirm.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_hardtanh_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardtanh_f64`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardtanh_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardtanh`, f64 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
/// NOTE(review): no clamp bounds are passed, so the `hardtanh` limits are
/// presumably fixed inside the kernel — confirm.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_hardtanh_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardtanh_f64_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardtanh_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `cbrt` — cube root across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
// Declarations below are compiled only when at least one of the `sm80`,
// `sm89`, or `sm90a` features is enabled. Each dtype exposes four entry
// points: contiguous run, contiguous check, strided run, strided check.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Unary elementwise `cbrt`, f32 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_cbrt_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cbrt_f32`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cbrt_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cbrt`, f32 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_cbrt_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cbrt_f32_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cbrt_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cbrt`, f16 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_cbrt_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cbrt_f16`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cbrt_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cbrt`, f16 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_cbrt_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cbrt_f16_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cbrt_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cbrt`, bf16 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_cbrt_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cbrt_bf16`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cbrt_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cbrt`, bf16 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_cbrt_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cbrt_bf16_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cbrt_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cbrt`, f64 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_cbrt_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cbrt_f64`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cbrt_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `cbrt`, f64 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_cbrt_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_cbrt_f64_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_cbrt_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `exp2` — `y = 2^x` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
// Declarations below are compiled only when at least one of the `sm80`,
// `sm89`, or `sm90a` features is enabled. Each dtype exposes four entry
// points: contiguous run, contiguous check, strided run, strided check.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Unary elementwise `exp2`, f32 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_exp2_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp2_f32`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_exp2_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `exp2`, f32 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_exp2_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp2_f32_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_exp2_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `exp2`, f16 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_exp2_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp2_f16`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_exp2_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `exp2`, f16 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_exp2_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp2_f16_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_exp2_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `exp2`, bf16 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_exp2_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp2_bf16`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_exp2_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `exp2`, bf16 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_exp2_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp2_bf16_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_exp2_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `exp2`, f64 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_exp2_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp2_f64`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_exp2_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `exp2`, f64 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_exp2_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_exp2_f64_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_exp2_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `log2` — base-2 log across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
// Declarations below are compiled only when at least one of the `sm80`,
// `sm89`, or `sm90a` features is enabled. Each dtype exposes four entry
// points: contiguous run, contiguous check, strided run, strided check.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Unary elementwise `log2`, f32 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_log2_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log2_f32`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log2_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log2`, f32 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_log2_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log2_f32_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log2_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log2`, f16 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_log2_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log2_f16`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log2_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log2`, f16 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_log2_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log2_f16_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log2_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log2`, bf16 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_log2_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log2_bf16`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log2_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log2`, bf16 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_log2_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log2_bf16_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log2_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log2`, f64 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_log2_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log2_f64`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log2_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log2`, f64 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_log2_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log2_f64_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log2_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `log10` — base-10 log across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
// Declarations below are compiled only when at least one of the `sm80`,
// `sm89`, or `sm90a` features is enabled. Each dtype exposes four entry
// points: contiguous run, contiguous check, strided run, strided check.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Unary elementwise `log10`, f32 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_log10_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log10_f32`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log10_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log10`, f32 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_log10_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log10_f32_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log10_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log10`, f16 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_log10_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log10_f16`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log10_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log10`, f16 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_log10_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log10_f16_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log10_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log10`, bf16 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_log10_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log10_bf16`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log10_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log10`, bf16 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_log10_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log10_bf16_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log10_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log10`, f64 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_log10_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log10_f64`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log10_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `log10`, f64 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_log10_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_log10_f64_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_log10_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `asin` — inverse sine across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
// Declarations below are compiled only when at least one of the `sm80`,
// `sm89`, or `sm90a` features is enabled. Each dtype exposes four entry
// points: contiguous run, contiguous check, strided run, strided check.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Unary elementwise `asin`, f32 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_asin_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asin_f32`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_asin_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `asin`, f32 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_asin_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asin_f32_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_asin_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `asin`, f16 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_asin_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asin_f16`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_asin_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `asin`, f16 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_asin_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asin_f16_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_asin_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `asin`, bf16 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_asin_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asin_bf16`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_asin_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `asin`, bf16 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_asin_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asin_bf16_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_asin_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `asin`, f64 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_asin_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asin_f64`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_asin_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `asin`, f64 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_asin_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asin_f64_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_asin_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `acos` — inverse cosine across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Unary elementwise `acos`, f32 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_acos_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acos_f32`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_acos_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `acos`, f32 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_acos_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acos_f32_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_acos_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `acos`, f16 dtype, contiguous fast path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `numel` is presumably the element count of both `x` (read)
/// and `y` (written) — confirm against the launcher.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_acos_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acos_f16`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_acos_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `acos`, f16 dtype, strided path.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// NOTE(review): `shape`, `stride_x` and `stride_y` presumably each hold
/// `rank` entries — confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_acos_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acos_f16_strided`.
///
/// Returns an `i32` status (`0` = success; see the crate-level status codes).
/// Accepts the strided launcher's shape / stride / pointer arguments, minus
/// `workspace`, `workspace_bytes`, and `stream`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_acos_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `acos`, bf16 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_acos_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acos_bf16`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_acos_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `acos`, bf16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_acos_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acos_bf16_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_acos_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `acos`, f64 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_acos_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acos_f64`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_acos_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `acos`, f64 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_acos_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acos_f64_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_acos_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `atan` — inverse tangent across f32 / f16 / bf16 / f64.
//
// Four entry points per dtype: contiguous `_run` / `_can_implement` and
// strided `_strided_run` / `_strided_can_implement`. Declared only when at
// least one of the `sm80` / `sm89` / `sm90a` features is enabled.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Unary elementwise `atan`, f32 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_atan_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atan_f32`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_atan_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `atan`, f32 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_atan_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atan_f32_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_atan_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `atan`, f16 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_atan_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atan_f16`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_atan_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `atan`, f16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_atan_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atan_f16_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_atan_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `atan`, bf16 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_atan_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atan_bf16`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_atan_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `atan`, bf16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_atan_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atan_bf16_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_atan_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `atan`, f64 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_atan_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atan_f64`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_atan_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `atan`, f64 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_atan_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atan_f64_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_atan_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `asinh` — inverse hyperbolic sine across f32 / f16 / bf16 / f64.
//
// Four entry points per dtype: contiguous `_run` / `_can_implement` and
// strided `_strided_run` / `_strided_can_implement`. Declared only when at
// least one of the `sm80` / `sm89` / `sm90a` features is enabled.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Unary elementwise `asinh`, f32 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_asinh_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asinh_f32`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_asinh_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `asinh`, f32 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_asinh_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asinh_f32_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_asinh_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `asinh`, f16 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_asinh_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asinh_f16`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_asinh_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `asinh`, f16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_asinh_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asinh_f16_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_asinh_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `asinh`, bf16 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_asinh_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asinh_bf16`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_asinh_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `asinh`, bf16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_asinh_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asinh_bf16_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_asinh_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `asinh`, f64 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_asinh_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asinh_f64`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_asinh_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `asinh`, f64 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_asinh_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_asinh_f64_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_asinh_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `acosh` — inverse hyperbolic cosine across f32 / f16 / bf16 / f64.
//
// Four entry points per dtype: contiguous `_run` / `_can_implement` and
// strided `_strided_run` / `_strided_can_implement`. Declared only when at
// least one of the `sm80` / `sm89` / `sm90a` features is enabled.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Unary elementwise `acosh`, f32 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_acosh_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acosh_f32`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_acosh_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `acosh`, f32 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_acosh_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acosh_f32_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_acosh_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `acosh`, f16 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_acosh_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acosh_f16`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_acosh_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `acosh`, f16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_acosh_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acosh_f16_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_acosh_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `acosh`, bf16 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_acosh_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acosh_bf16`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_acosh_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `acosh`, bf16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_acosh_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acosh_bf16_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_acosh_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `acosh`, f64 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_acosh_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acosh_f64`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_acosh_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `acosh`, f64 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_acosh_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_acosh_f64_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_acosh_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `atanh` — inverse hyperbolic tangent across f32 / f16 / bf16 / f64.
//
// Four entry points per dtype: contiguous `_run` / `_can_implement` and
// strided `_strided_run` / `_strided_can_implement`. Declared only when at
// least one of the `sm80` / `sm89` / `sm90a` features is enabled.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Unary elementwise `atanh`, f32 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_atanh_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atanh_f32`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_atanh_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `atanh`, f32 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_atanh_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atanh_f32_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_atanh_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `atanh`, f16 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_atanh_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atanh_f16`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_atanh_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `atanh`, f16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_atanh_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atanh_f16_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_atanh_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `atanh`, bf16 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_atanh_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atanh_bf16`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_atanh_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `atanh`, bf16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_atanh_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atanh_bf16_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_atanh_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `atanh`, f64 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_atanh_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atanh_f64`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_atanh_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `atanh`, f64 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_atanh_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_atanh_f64_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_atanh_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `floor` — round toward -infinity across f32 / f16 / bf16 / f64.
//
// Four entry points per dtype: contiguous `_run` / `_can_implement` and
// strided `_strided_run` / `_strided_can_implement`. Declared only when at
// least one of the `sm80` / `sm89` / `sm90a` features is enabled.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Unary elementwise `floor`, f32 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_floor_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_floor_f32`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_floor_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `floor`, f32 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_floor_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_floor_f32_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_floor_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `floor`, f16 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_floor_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_floor_f16`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_floor_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `floor`, f16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_floor_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_floor_f16_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_floor_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `floor`, bf16 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_floor_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_floor_bf16`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_floor_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `floor`, bf16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_floor_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_floor_bf16_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_floor_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `floor`, f64 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_floor_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_floor_f64`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_floor_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `floor`, f64 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_floor_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_floor_f64_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_floor_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `ceil` — round toward +infinity across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// Unary elementwise `ceil`, f32 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_ceil_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_ceil_f32`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_ceil_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `ceil`, f32 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` alongside the `x` / `y`
/// pointers, workspace and stream. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
/// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries each —
/// confirm against the unary-neg strided launcher.
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_ceil_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_ceil_f32_strided`.
///
/// Takes `numel`, `rank`, the shape / stride pointers, `x` and `y` (no
/// workspace or stream). Returns an `i32` status per the crate-level
/// "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_ceil_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `ceil`, f16 dtype, contiguous fast path.
///
/// `x` is the read-only operand and `y` the writable one; `stream` is a
/// `cudaStream_t` passed as `*mut c_void`. Returns an `i32` status per the
/// crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
/// `workspace`, when non-null, must point to at least `workspace_bytes` of
/// writable device memory.
pub fn baracuda_kernels_unary_ceil_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_ceil_f16`.
///
/// Takes only `numel`, `x` and `y` (no workspace or stream). Returns an
/// `i32` status per the crate-level "Status codes" list (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_ceil_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `ceil`, f16 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_ceil_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_ceil_f16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_ceil_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `ceil`, bf16 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_ceil_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_ceil_bf16`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_ceil_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `ceil`, bf16 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_ceil_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_ceil_bf16_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_ceil_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `ceil`, f64 dtype, contiguous fast path.
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_ceil_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_ceil_f64`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_ceil_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `ceil`, f64 dtype, strided path.
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_ceil_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_ceil_f64_strided`.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_ceil_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `round` — round-half-to-even (PyTorch convention) across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Declarations for the unary `round` entry points, compiled in only when at
// least one of the `sm80` / `sm89` / `sm90a` features is enabled.
//
// Each dtype (f32, f16, bf16, f64) gets four functions:
//   `*_run`                    launcher for the contiguous fast path
//   `*_can_implement`          pre-launch check paired with `*_run`
//   `*_strided_run`            launcher for the strided path
//   `*_strided_can_implement`  pre-launch check paired with `*_strided_run`
//
// Per the crate-level docs: every function returns an `i32` status code
// (`0` = success), `stream` is a `cudaStream_t` passed as `*mut c_void`, and
// a non-null `workspace` must point to at least `workspace_bytes` of
// writable device memory.
//
// NOTE(review): `numel` is presumably the total element count, and `shape` /
// `stride_x` / `stride_y` presumably hold `rank` entries each; neither the
// stride units nor the f32 storage type are stated in this part of the file.
// Confirm against the unary-neg launchers these docs defer to.
/// Unary elementwise `round`, f32 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_round_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_round_f32`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_round_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `round`, f32 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_round_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_round_f32_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_round_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `round`, f16 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_round_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_round_f16`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_round_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `round`, f16 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_round_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_round_f16_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_round_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `round`, bf16 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_round_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_round_bf16`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_round_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `round`, bf16 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_round_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_round_bf16_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_round_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `round`, f64 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_round_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_round_f64`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_round_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `round`, f64 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_round_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_round_f64_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_round_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `trunc` — round toward zero across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Declarations for the unary `trunc` entry points, compiled in only when at
// least one of the `sm80` / `sm89` / `sm90a` features is enabled.
//
// Each dtype (f32, f16, bf16, f64) gets four functions:
//   `*_run`                    launcher for the contiguous fast path
//   `*_can_implement`          pre-launch check paired with `*_run`
//   `*_strided_run`            launcher for the strided path
//   `*_strided_can_implement`  pre-launch check paired with `*_strided_run`
//
// Per the crate-level docs: every function returns an `i32` status code
// (`0` = success), `stream` is a `cudaStream_t` passed as `*mut c_void`, and
// a non-null `workspace` must point to at least `workspace_bytes` of
// writable device memory.
//
// NOTE(review): `numel` is presumably the total element count, and `shape` /
// `stride_x` / `stride_y` presumably hold `rank` entries each; neither the
// stride units nor the f32 storage type are stated in this part of the file.
// Confirm against the unary-neg launchers these docs defer to.
/// Unary elementwise `trunc`, f32 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_trunc_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_trunc_f32`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_trunc_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `trunc`, f32 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_trunc_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_trunc_f32_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_trunc_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `trunc`, f16 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_trunc_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_trunc_f16`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_trunc_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `trunc`, f16 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_trunc_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_trunc_f16_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_trunc_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `trunc`, bf16 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_trunc_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_trunc_bf16`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_trunc_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `trunc`, bf16 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_trunc_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_trunc_bf16_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_trunc_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `trunc`, f64 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_trunc_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_trunc_f64`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_trunc_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `trunc`, f64 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_trunc_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_trunc_f64_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_trunc_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `frac` — fractional part (result keeps the sign of x) across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Declarations for the unary `frac` entry points, compiled in only when at
// least one of the `sm80` / `sm89` / `sm90a` features is enabled.
//
// Each dtype (f32, f16, bf16, f64) gets four functions:
//   `*_run`                    launcher for the contiguous fast path
//   `*_can_implement`          pre-launch check paired with `*_run`
//   `*_strided_run`            launcher for the strided path
//   `*_strided_can_implement`  pre-launch check paired with `*_strided_run`
//
// Per the crate-level docs: every function returns an `i32` status code
// (`0` = success), `stream` is a `cudaStream_t` passed as `*mut c_void`, and
// a non-null `workspace` must point to at least `workspace_bytes` of
// writable device memory.
//
// NOTE(review): `numel` is presumably the total element count, and `shape` /
// `stride_x` / `stride_y` presumably hold `rank` entries each; neither the
// stride units nor the f32 storage type are stated in this part of the file.
// Confirm against the unary-neg launchers these docs defer to.
/// Unary elementwise `frac`, f32 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_frac_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_frac_f32`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_frac_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `frac`, f32 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_frac_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_frac_f32_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_frac_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `frac`, f16 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_frac_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_frac_f16`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_frac_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `frac`, f16 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_frac_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_frac_f16_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_frac_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `frac`, bf16 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_frac_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_frac_bf16`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_frac_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `frac`, bf16 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_frac_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_frac_bf16_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_frac_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `frac`, f64 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_frac_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_frac_f64`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_frac_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `frac`, f64 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_frac_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_frac_f64_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_frac_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `erf` — `y = erf(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Declarations for the unary `erf` entry points, compiled in only when at
// least one of the `sm80` / `sm89` / `sm90a` features is enabled.
//
// Each dtype (f32, f16, bf16, f64) gets four functions:
//   `*_run`                    launcher for the contiguous fast path
//   `*_can_implement`          pre-launch check paired with `*_run`
//   `*_strided_run`            launcher for the strided path
//   `*_strided_can_implement`  pre-launch check paired with `*_strided_run`
//
// Per the crate-level docs: every function returns an `i32` status code
// (`0` = success), `stream` is a `cudaStream_t` passed as `*mut c_void`, and
// a non-null `workspace` must point to at least `workspace_bytes` of
// writable device memory.
//
// NOTE(review): `numel` is presumably the total element count, and `shape` /
// `stride_x` / `stride_y` presumably hold `rank` entries each; neither the
// stride units nor the f32 storage type are stated in this part of the file.
// Confirm against the unary-neg launchers these docs defer to.
/// Unary elementwise `erf`, f32 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_erf_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erf_f32`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_erf_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `erf`, f32 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_erf_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erf_f32_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_erf_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `erf`, f16 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_erf_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erf_f16`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_erf_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `erf`, f16 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_erf_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erf_f16_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_erf_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `erf`, bf16 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_erf_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erf_bf16`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_erf_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `erf`, bf16 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_erf_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erf_bf16_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_erf_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `erf`, f64 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_erf_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erf_f64`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_erf_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `erf`, f64 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_erf_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erf_f64_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_erf_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `erfc` — `y = erfc(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Declarations for the unary `erfc` entry points, compiled in only when at
// least one of the `sm80` / `sm89` / `sm90a` features is enabled.
//
// Each dtype (f32, f16, bf16, f64) gets four functions:
//   `*_run`                    launcher for the contiguous fast path
//   `*_can_implement`          pre-launch check paired with `*_run`
//   `*_strided_run`            launcher for the strided path
//   `*_strided_can_implement`  pre-launch check paired with `*_strided_run`
//
// Per the crate-level docs: every function returns an `i32` status code
// (`0` = success), `stream` is a `cudaStream_t` passed as `*mut c_void`, and
// a non-null `workspace` must point to at least `workspace_bytes` of
// writable device memory.
//
// NOTE(review): `numel` is presumably the total element count, and `shape` /
// `stride_x` / `stride_y` presumably hold `rank` entries each; neither the
// stride units nor the f32 storage type are stated in this part of the file.
// Confirm against the unary-neg launchers these docs defer to.
/// Unary elementwise `erfc`, f32 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
pub fn baracuda_kernels_unary_erfc_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erfc_f32`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_erfc_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `erfc`, f32 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
pub fn baracuda_kernels_unary_erfc_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erfc_f32_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_erfc_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `erfc`, f16 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_erfc_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erfc_f16`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_erfc_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `erfc`, f16 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_erfc_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erfc_f16_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_erfc_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `erfc`, bf16 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_erfc_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erfc_bf16`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_erfc_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `erfc`, bf16 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_erfc_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erfc_bf16_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_erfc_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `erfc`, f64 dtype, contiguous fast path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_erfc_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erfc_f64`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_erfc_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `erfc`, f64 dtype, strided path.
///
/// Returns an `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_erfc_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_erfc_f64_strided`.
///
/// Returns an `i32` status code (`0` = success); takes no workspace or stream.
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_erfc_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `lgamma` — `y = lgamma(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Raw launchers for the elementwise `lgamma` kernels, compiled in only when
// one of the `sm80` / `sm89` / `sm90a` features is enabled. Each dtype (f32,
// f16, bf16, f64) exposes four symbols: `*_run` (contiguous fast path),
// `*_can_implement`, `*_strided_run` and `*_strided_can_implement`. All of
// them return the crate-wide `i32` status code (see "Status codes" in the
// crate docs).
// NOTE(review): presumably `lgamma(x)` is the natural log of |Γ(x)|, as in the
// CUDA math library — confirm against the kernel source.
// NOTE(review): presumably `shape`, `stride_x` and `stride_y` each hold `rank`
// entries; whether they are host or device arrays is not visible here — confirm.
/// Unary elementwise `lgamma`, f32 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_lgamma_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_lgamma_f32`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_lgamma_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `lgamma`, f32 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_lgamma_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_lgamma_f32_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_lgamma_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `lgamma`, f16 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_lgamma_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_lgamma_f16`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_lgamma_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `lgamma`, f16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_lgamma_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_lgamma_f16_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_lgamma_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `lgamma`, bf16 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_lgamma_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_lgamma_bf16`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_lgamma_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `lgamma`, bf16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_lgamma_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_lgamma_bf16_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_lgamma_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `lgamma`, f64 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_lgamma_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_lgamma_f64`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_lgamma_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `lgamma`, f64 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_lgamma_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_lgamma_f64_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_lgamma_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `logit` — `y = logit(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Raw launchers for the elementwise `logit` kernels, compiled in only when
// one of the `sm80` / `sm89` / `sm90a` features is enabled. Each dtype (f32,
// f16, bf16, f64) exposes four symbols: `*_run` (contiguous fast path),
// `*_can_implement`, `*_strided_run` and `*_strided_can_implement`. All of
// them return the crate-wide `i32` status code (see "Status codes" in the
// crate docs).
// NOTE(review): presumably `logit(x)` is `ln(x / (1 - x))` with no epsilon
// clamping — confirm against the kernel source.
// NOTE(review): presumably `shape`, `stride_x` and `stride_y` each hold `rank`
// entries; whether they are host or device arrays is not visible here — confirm.
/// Unary elementwise `logit`, f32 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_logit_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_logit_f32`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_logit_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `logit`, f32 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_logit_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_logit_f32_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_logit_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `logit`, f16 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_logit_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_logit_f16`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_logit_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `logit`, f16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_logit_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_logit_f16_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_logit_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `logit`, bf16 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_logit_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_logit_bf16`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_logit_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `logit`, bf16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_logit_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_logit_bf16_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_logit_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `logit`, f64 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_logit_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_logit_f64`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_logit_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `logit`, f64 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_logit_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_logit_f64_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_logit_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `softsign` — `y = softsign(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Raw launchers for the elementwise `softsign` kernels, compiled in only when
// one of the `sm80` / `sm89` / `sm90a` features is enabled. Each dtype (f32,
// f16, bf16, f64) exposes four symbols: `*_run` (contiguous fast path),
// `*_can_implement`, `*_strided_run` and `*_strided_can_implement`. All of
// them return the crate-wide `i32` status code (see "Status codes" in the
// crate docs).
// NOTE(review): presumably `softsign(x)` is `x / (1 + |x|)` — confirm against
// the kernel source.
// NOTE(review): presumably `shape`, `stride_x` and `stride_y` each hold `rank`
// entries; whether they are host or device arrays is not visible here — confirm.
/// Unary elementwise `softsign`, f32 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_softsign_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softsign_f32`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softsign_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `softsign`, f32 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_softsign_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softsign_f32_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softsign_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `softsign`, f16 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_softsign_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softsign_f16`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softsign_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `softsign`, f16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_softsign_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softsign_f16_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softsign_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `softsign`, bf16 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_softsign_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softsign_bf16`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softsign_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `softsign`, bf16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_softsign_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softsign_bf16_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softsign_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `softsign`, f64 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_softsign_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softsign_f64`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softsign_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `softsign`, f64 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_softsign_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softsign_f64_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softsign_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `tanhshrink` — `y = tanhshrink(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Raw launchers for the elementwise `tanhshrink` kernels, compiled in only
// when one of the `sm80` / `sm89` / `sm90a` features is enabled. Each dtype
// (f32, f16, bf16, f64) exposes four symbols: `*_run` (contiguous fast path),
// `*_can_implement`, `*_strided_run` and `*_strided_can_implement`. All of
// them return the crate-wide `i32` status code (see "Status codes" in the
// crate docs).
// NOTE(review): presumably `tanhshrink(x)` is `x - tanh(x)` — confirm against
// the kernel source.
// NOTE(review): presumably `shape`, `stride_x` and `stride_y` each hold `rank`
// entries; whether they are host or device arrays is not visible here — confirm.
/// Unary elementwise `tanhshrink`, f32 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_tanhshrink_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanhshrink_f32`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tanhshrink_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tanhshrink`, f32 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_tanhshrink_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanhshrink_f32_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tanhshrink_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tanhshrink`, f16 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_tanhshrink_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanhshrink_f16`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tanhshrink_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tanhshrink`, f16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_tanhshrink_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanhshrink_f16_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tanhshrink_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tanhshrink`, bf16 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_tanhshrink_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanhshrink_bf16`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tanhshrink_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tanhshrink`, bf16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_tanhshrink_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanhshrink_bf16_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tanhshrink_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tanhshrink`, f64 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_tanhshrink_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanhshrink_f64`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tanhshrink_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `tanhshrink`, f64 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_tanhshrink_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_tanhshrink_f64_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_tanhshrink_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `relu6` — `y = relu6(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Raw launchers for the elementwise `relu6` kernels, compiled in only when
// one of the `sm80` / `sm89` / `sm90a` features is enabled. Each dtype (f32,
// f16, bf16, f64) exposes four symbols: `*_run` (contiguous fast path),
// `*_can_implement`, `*_strided_run` and `*_strided_can_implement`. All of
// them return the crate-wide `i32` status code (see "Status codes" in the
// crate docs).
// NOTE(review): presumably `relu6(x)` is `min(max(x, 0), 6)` — confirm against
// the kernel source.
// NOTE(review): presumably `shape`, `stride_x` and `stride_y` each hold `rank`
// entries; whether they are host or device arrays is not visible here — confirm.
/// Unary elementwise `relu6`, f32 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_relu6_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu6_f32`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_relu6_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `relu6`, f32 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_relu6_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu6_f32_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_relu6_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `relu6`, f16 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_relu6_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu6_f16`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_relu6_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `relu6`, f16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_relu6_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu6_f16_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_relu6_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `relu6`, bf16 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_relu6_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu6_bf16`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_relu6_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `relu6`, bf16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_relu6_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu6_bf16_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_relu6_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `relu6`, f64 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_relu6_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu6_f64`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_relu6_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `relu6`, f64 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_relu6_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_relu6_f64_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_relu6_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Unary `selu` — `y = selu(x)` across f32 / f16 / bf16 / f64.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Raw launchers for the elementwise `selu` kernels, compiled in only when
// one of the `sm80` / `sm89` / `sm90a` features is enabled. Each dtype (f32,
// f16, bf16, f64) exposes four symbols: `*_run` (contiguous fast path),
// `*_can_implement`, `*_strided_run` and `*_strided_can_implement`. All of
// them return the crate-wide `i32` status code (see "Status codes" in the
// crate docs).
// NOTE(review): presumably `selu(x)` is
// `scale * (max(0, x) + min(0, alpha * (exp(x) - 1)))` with the standard fixed
// SELU constants — confirm against the kernel source.
// NOTE(review): presumably `shape`, `stride_x` and `stride_y` each hold `rank`
// entries; whether they are host or device arrays is not visible here — confirm.
/// Unary elementwise `selu`, f32 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_selu_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_selu_f32`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_selu_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `selu`, f32 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_selu_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_selu_f32_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_selu_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `selu`, f16 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__half` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_selu_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_selu_f16`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_selu_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `selu`, f16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__half` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_selu_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_selu_f16_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_selu_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `selu`, bf16 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `__nv_bfloat16` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_selu_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_selu_bf16`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_selu_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `selu`, bf16 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `__nv_bfloat16` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_selu_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_selu_bf16_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_selu_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `selu`, f64 dtype, contiguous fast path.
///
/// Reads the input from `x`, writes the result to `y`, and returns the
/// crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-neg trailblazer. `x` / `y` point to `double` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_selu_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_selu_f64`.
///
/// Accepts the `numel` / `x` / `y` arguments of the matching `_run` call and
/// returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_selu_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `selu`, f64 dtype, strided path.
///
/// Takes `rank`, `shape`, `stride_x` and `stride_y` in addition to the
/// contiguous-path arguments; returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// Same contract as the unary-neg strided launcher. `x` / `y` point to `double` storage.
/// Crate-level requirements also apply: `x` / `y` are valid device addresses,
/// a non-null `workspace` spans at least `workspace_bytes` of writable device
/// memory, and `stream` is a valid CUDA stream.
pub fn baracuda_kernels_unary_selu_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_selu_f64_strided`.
///
/// Accepts the layout and pointer arguments of the matching `_strided_run`
/// call and returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_selu_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// ----------------------------------------------------------------------------
// Parameterized-activation FW fanout — LeakyRelu / ELU / Hardshrink /
// Softshrink across f32 / f16 / bf16 / f64. Parameters are hardcoded
// (LeakyRelu α=0.01, Hardshrink λ=0.5, Softshrink λ=0.5) to match PyTorch
// defaults. When the parameterized-unary plan ships these re-emit with the
// parameter as a runtime arg — same dispatch shape, no extern signature
// change.
//
// Exception: ELU is no longer hardcoded. Since Phase 31 its `_run`,
// `_strided_run`, and `_strided_can_implement` symbols take `alpha: f32`
// at runtime (a breaking signature change); pass `alpha = 1.0` for the
// PyTorch default.
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---- LeakyReLU (α=0.01) ----
/// Unary elementwise `leaky_relu` (α=0.01), f32, contig.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-relu trailblazer.
pub fn baracuda_kernels_unary_leaky_relu_f32_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_leaky_relu_f32`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_leaky_relu_f32_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// Unary elementwise `leaky_relu` (α=0.01), f32, strided.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-relu strided launcher.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_leaky_relu_f32_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_leaky_relu_f32_strided`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_leaky_relu_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `leaky_relu` (α=0.01), f16, contig.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__half` storage; same contract as the unary-relu trailblazer.
pub fn baracuda_kernels_unary_leaky_relu_f16_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_leaky_relu_f16`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_leaky_relu_f16_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// Unary elementwise `leaky_relu` (α=0.01), f16, strided.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__half` storage.
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_leaky_relu_f16_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_leaky_relu_f16_strided`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_leaky_relu_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `leaky_relu` (α=0.01), bf16, contig.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__nv_bfloat16` storage.
/// Crate-level device-pointer / workspace / stream contract applies.
pub fn baracuda_kernels_unary_leaky_relu_bf16_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_leaky_relu_bf16`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_leaky_relu_bf16_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// Unary elementwise `leaky_relu` (α=0.01), bf16, strided.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__nv_bfloat16` storage.
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_leaky_relu_bf16_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_leaky_relu_bf16_strided`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_leaky_relu_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `leaky_relu` (α=0.01), f64, contig.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `double` storage.
/// Crate-level device-pointer / workspace / stream contract applies.
pub fn baracuda_kernels_unary_leaky_relu_f64_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_leaky_relu_f64`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_leaky_relu_f64_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// Unary elementwise `leaky_relu` (α=0.01), f64, strided.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `double` storage.
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_leaky_relu_f64_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_leaky_relu_f64_strided`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_leaky_relu_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// ---- ELU (α = runtime parameter) ----
//
// Phase 31 BREAKING CHANGE: the ELU `_run`, `_strided_run`, and
// `_strided_can_implement` FFI symbols gained an `alpha: f32` parameter
// (Fuel Phase 6c.2 storage.rs unblock); the contiguous `_can_implement`
// symbols take no `alpha`. The old hardcoded-α=1.0 signature is gone —
// callers must thread α through explicitly. For PyTorch's `nn.ELU`
// default behaviour pass `alpha = 1.0`. f64 widens internally from the
// f32 ABI.
/// Unary elementwise `elu(x; α) = x if x>0 else α·(exp(x)-1)`, f32, contig.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-relu trailblazer.
/// `alpha` is the negative-branch scale; pass `1.0` for the
/// PyTorch default.
pub fn baracuda_kernels_unary_elu_f32_run(
numel: i64, x: *const c_void, y: *mut c_void,
alpha: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_elu_f32`.
///
/// Takes no `alpha`, unlike the matching `_run`. Returns the crate-level
/// `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_elu_f32_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// Unary elementwise `elu(x; α)`, f32, strided.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-relu strided launcher. `alpha` carries
/// the negative-branch scale.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_elu_f32_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
alpha: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_elu_f32_strided`. Host-side only.
///
/// Returns the crate-level `i32` status code (`0` = success).
/// NOTE(review): takes `alpha`, unlike the contiguous
/// `baracuda_kernels_unary_elu_f32_can_implement` — confirm the
/// asymmetry against the C header.
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_elu_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
alpha: f32,
) -> i32;
/// Unary elementwise `elu(x; α)`, f16, contig.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__half` storage. `alpha` is f32 and is
/// applied inside the f32 detour.
/// Crate-level device-pointer / workspace / stream contract applies.
pub fn baracuda_kernels_unary_elu_f16_run(
numel: i64, x: *const c_void, y: *mut c_void,
alpha: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_elu_f16`.
///
/// Takes no `alpha`, unlike the matching `_run`. Returns the crate-level
/// `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_elu_f16_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// Unary elementwise `elu(x; α)`, f16, strided.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__half` storage.
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_elu_f16_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
alpha: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_elu_f16_strided`. Host-side only.
///
/// Returns the crate-level `i32` status code (`0` = success).
/// NOTE(review): takes `alpha`, unlike the contiguous
/// `baracuda_kernels_unary_elu_f16_can_implement` — confirm the
/// asymmetry against the C header.
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_elu_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
alpha: f32,
) -> i32;
/// Unary elementwise `elu(x; α)`, bf16, contig.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__nv_bfloat16` storage.
/// Crate-level device-pointer / workspace / stream contract applies.
pub fn baracuda_kernels_unary_elu_bf16_run(
numel: i64, x: *const c_void, y: *mut c_void,
alpha: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_elu_bf16`.
///
/// Takes no `alpha`, unlike the matching `_run`. Returns the crate-level
/// `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_elu_bf16_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// Unary elementwise `elu(x; α)`, bf16, strided.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__nv_bfloat16` storage.
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_elu_bf16_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
alpha: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_elu_bf16_strided`. Host-side only.
///
/// Returns the crate-level `i32` status code (`0` = success).
/// NOTE(review): takes `alpha`, unlike the contiguous
/// `baracuda_kernels_unary_elu_bf16_can_implement` — confirm the
/// asymmetry against the C header.
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_elu_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
alpha: f32,
) -> i32;
/// Unary elementwise `elu(x; α)`, f64, contig.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `double` storage. `alpha` is f32 in the ABI
/// and widened to double inside the kernel — this is the safest
/// choice because most Fuel call sites pass `alpha` as `f32` and
/// PyTorch's `alpha` default (1.0) round-trips exactly.
/// Crate-level device-pointer / workspace / stream contract applies.
pub fn baracuda_kernels_unary_elu_f64_run(
numel: i64, x: *const c_void, y: *mut c_void,
alpha: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_elu_f64`.
///
/// Takes no `alpha`, unlike the matching `_run`. Returns the crate-level
/// `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_elu_f64_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// Unary elementwise `elu(x; α)`, f64, strided.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `double` storage.
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_elu_f64_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
alpha: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_elu_f64_strided`. Host-side only.
///
/// Returns the crate-level `i32` status code (`0` = success).
/// NOTE(review): takes `alpha`, unlike the contiguous
/// `baracuda_kernels_unary_elu_f64_can_implement` — confirm the
/// asymmetry against the C header.
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_elu_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
alpha: f32,
) -> i32;
// ---- Phase 31 — PowF (float-exponent power) ----
//
// `y = pow(x, exponent)`. Distinct from `unary_powi_*` (integer
// exponent via power-by-squaring) and from `binary_pow_*` (per-
// element f32 exponent tensor). FFI shape: single `exponent: f32`
// parameter, NOT the unary-param `(p0, p1)` slot — Fuel ask.
/// Unary elementwise `pow(x, exponent)`, f32, contig.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as `unary_relu_*_run`.
/// `exponent` is broadcast over every element. f32 uses `__powf`
/// (≤4 ULP); fall back to bare `powf` if you need strict-ULP
/// guarantees.
pub fn baracuda_kernels_unary_powf_f32_run(
numel: i64, x: *const c_void, y: *mut c_void,
exponent: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `unary_powf_f32`.
///
/// Takes no `exponent`, unlike the matching `_run`. Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Presumably host-side checks only, like the sibling `*_can_implement`
/// entry points — confirm against the C implementation.
pub fn baracuda_kernels_unary_powf_f32_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// `unary_powf` (`y = pow(x, exponent)`), f32, strided sibling.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_powf_f32_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
exponent: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_powf_f32_strided`. Host-side only.
///
/// Returns the crate-level `i32` status code (`0` = success).
/// NOTE(review): takes `exponent`, unlike the contiguous
/// `baracuda_kernels_unary_powf_f32_can_implement` — confirm the
/// asymmetry against the C header.
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_powf_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
exponent: f32,
) -> i32;
/// `unary_powf` (`y = pow(x, exponent)`), f16, contig. f32 detour.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
pub fn baracuda_kernels_unary_powf_f16_run(
numel: i64, x: *const c_void, y: *mut c_void,
exponent: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_powf_f16`.
///
/// Takes no `exponent`, unlike the matching `_run`. Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Presumably host-side checks only, like the sibling `*_can_implement`
/// entry points — confirm against the C implementation.
pub fn baracuda_kernels_unary_powf_f16_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// `unary_powf` (`y = pow(x, exponent)`), f16, strided sibling.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_powf_f16_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
exponent: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_powf_f16_strided`. Host-side only.
///
/// Returns the crate-level `i32` status code (`0` = success).
/// NOTE(review): takes `exponent`, unlike the contiguous
/// `baracuda_kernels_unary_powf_f16_can_implement` — confirm the
/// asymmetry against the C header.
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_powf_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
exponent: f32,
) -> i32;
/// `unary_powf` (`y = pow(x, exponent)`), bf16, contig.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
pub fn baracuda_kernels_unary_powf_bf16_run(
numel: i64, x: *const c_void, y: *mut c_void,
exponent: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_powf_bf16`.
///
/// Takes no `exponent`, unlike the matching `_run`. Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Presumably host-side checks only, like the sibling `*_can_implement`
/// entry points — confirm against the C implementation.
pub fn baracuda_kernels_unary_powf_bf16_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// `unary_powf` (`y = pow(x, exponent)`), bf16, strided sibling.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_powf_bf16_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
exponent: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_powf_bf16_strided`. Host-side only.
///
/// Returns the crate-level `i32` status code (`0` = success).
/// NOTE(review): takes `exponent`, unlike the contiguous
/// `baracuda_kernels_unary_powf_bf16_can_implement` — confirm the
/// asymmetry against the C header.
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_powf_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
exponent: f32,
) -> i32;
/// `unary_powf`, f64, contig. `pow` (libdevice) is full-double
/// precision; the f32 exponent is widened once at kernel entry.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
pub fn baracuda_kernels_unary_powf_f64_run(
numel: i64, x: *const c_void, y: *mut c_void,
exponent: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_powf_f64`.
///
/// Takes no `exponent`, unlike the matching `_run`. Returns the
/// crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Presumably host-side checks only, like the sibling `*_can_implement`
/// entry points — confirm against the C implementation.
pub fn baracuda_kernels_unary_powf_f64_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// `unary_powf` (`y = pow(x, exponent)`), f64, strided sibling.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_powf_f64_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
exponent: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_powf_f64_strided`. Host-side only.
///
/// Returns the crate-level `i32` status code (`0` = success).
/// NOTE(review): takes `exponent`, unlike the contiguous
/// `baracuda_kernels_unary_powf_f64_can_implement` — confirm the
/// asymmetry against the C header.
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_powf_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
exponent: f32,
) -> i32;
// ---- Phase 31 — Step (Heaviside) ----
//
// `y = (x > 0) ? 1 : 0`. NaN → 0 (NaN > 0 is false). Bare unary
// ABI — same shape as `unary_relu_*_run`.
/// `unary_step` (`y = (x > 0) ? 1 : 0`, NaN → 0), f32, contig.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
pub fn baracuda_kernels_unary_step_f32_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_step_f32`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Presumably host-side checks only, like the sibling `*_can_implement`
/// entry points — confirm against the C implementation.
pub fn baracuda_kernels_unary_step_f32_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// `unary_step` (`y = (x > 0) ? 1 : 0`, NaN → 0), f32, strided.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_step_f32_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_step_f32_strided`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_step_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `unary_step` (`y = (x > 0) ? 1 : 0`, NaN → 0), f16, contig.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
pub fn baracuda_kernels_unary_step_f16_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_step_f16`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Presumably host-side checks only, like the sibling `*_can_implement`
/// entry points — confirm against the C implementation.
pub fn baracuda_kernels_unary_step_f16_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// `unary_step` (`y = (x > 0) ? 1 : 0`, NaN → 0), f16, strided.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_step_f16_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_step_f16_strided`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_step_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `unary_step` (`y = (x > 0) ? 1 : 0`, NaN → 0), bf16, contig.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
pub fn baracuda_kernels_unary_step_bf16_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_step_bf16`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Presumably host-side checks only, like the sibling `*_can_implement`
/// entry points — confirm against the C implementation.
pub fn baracuda_kernels_unary_step_bf16_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// `unary_step` (`y = (x > 0) ? 1 : 0`, NaN → 0), bf16, strided.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_step_bf16_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_step_bf16_strided`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_step_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `unary_step` (`y = (x > 0) ? 1 : 0`, NaN → 0), f64, contig.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
pub fn baracuda_kernels_unary_step_f64_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_step_f64`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Presumably host-side checks only, like the sibling `*_can_implement`
/// entry points — confirm against the C implementation.
pub fn baracuda_kernels_unary_step_f64_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// `unary_step` (`y = (x > 0) ? 1 : 0`, NaN → 0), f64, strided.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_step_f64_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_step_f64_strided`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_step_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// ---- Phase 31 — GeluErf (exact erf-based GELU) ----
//
// `y = 0.5 * x * (1 + erf(x / sqrt(2)))`. Distinct symbol name
// for the exact erf-based GELU formula so Fuel's storage.rs can
// route precisely (vs the tanh-approximation `unary_gelu_tanh_*`).
//
// NOTE: `unary_gelu_*` ALSO implements the erf-based formula
// today — both symbols coexist with bit-identical math. The
// duplication is intentional (Fuel ask); a future consolidation
// may make `unary_gelu_*` an alias of `unary_gelu_erf_*`.
/// `unary_gelu_erf` (`y = 0.5 * x * (1 + erf(x / sqrt(2)))`), f32, contig.
///
/// Bit-identical alias of `unary_gelu_*` (ERF-EXACT flavor) —
/// added in Phase 31 so consumers can bind the flavor
/// unambiguously by name.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
pub fn baracuda_kernels_unary_gelu_erf_f32_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_erf_f32`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Presumably host-side checks only, like the sibling `*_can_implement`
/// entry points — confirm against the C implementation.
pub fn baracuda_kernels_unary_gelu_erf_f32_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// `unary_gelu_erf` (`y = 0.5 * x * (1 + erf(x / sqrt(2)))`), f32, strided.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_gelu_erf_f32_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_erf_f32_strided`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_erf_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `unary_gelu_erf` (`y = 0.5 * x * (1 + erf(x / sqrt(2)))`), f16, contig.
///
/// Bit-identical alias of `unary_gelu_*` (ERF-EXACT flavor) —
/// added in Phase 31 so consumers can bind the flavor
/// unambiguously by name.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
pub fn baracuda_kernels_unary_gelu_erf_f16_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_erf_f16`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Presumably host-side checks only, like the sibling `*_can_implement`
/// entry points — confirm against the C implementation.
pub fn baracuda_kernels_unary_gelu_erf_f16_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// `unary_gelu_erf` (`y = 0.5 * x * (1 + erf(x / sqrt(2)))`), f16, strided.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_gelu_erf_f16_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_erf_f16_strided`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_erf_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `unary_gelu_erf` (`y = 0.5 * x * (1 + erf(x / sqrt(2)))`), bf16, contig.
///
/// Bit-identical alias of `unary_gelu_*` (ERF-EXACT flavor) —
/// added in Phase 31 so consumers can bind the flavor
/// unambiguously by name.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
pub fn baracuda_kernels_unary_gelu_erf_bf16_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_erf_bf16`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Presumably host-side checks only, like the sibling `*_can_implement`
/// entry points — confirm against the C implementation.
pub fn baracuda_kernels_unary_gelu_erf_bf16_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// `unary_gelu_erf` (`y = 0.5 * x * (1 + erf(x / sqrt(2)))`), bf16, strided.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_gelu_erf_bf16_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_erf_bf16_strided`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_erf_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// `unary_gelu_erf` (`y = 0.5 * x * (1 + erf(x / sqrt(2)))`), f64, contig.
///
/// Bit-identical alias of `unary_gelu_*` (ERF-EXACT flavor) —
/// added in Phase 31 so consumers can bind the flavor
/// unambiguously by name.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
pub fn baracuda_kernels_unary_gelu_erf_f64_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_erf_f64`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Presumably host-side checks only, like the sibling `*_can_implement`
/// entry points — confirm against the C implementation.
pub fn baracuda_kernels_unary_gelu_erf_f64_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// `unary_gelu_erf` (`y = 0.5 * x * (1 + erf(x / sqrt(2)))`), f64, strided.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Crate-level device-pointer / workspace / stream contract applies.
/// NOTE(review): `shape` / `stride_x` / `stride_y` presumably hold `rank`
/// entries each — confirm against the C launcher.
pub fn baracuda_kernels_unary_gelu_erf_f64_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_gelu_erf_f64_strided`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_gelu_erf_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// ---- Phase 31 — ReduceSumTo / ReduceMaxTo (broadcast-reverse) ----
//
// The autograd primitives that undo a forward `BroadcastTo`. For
// each output cell, sum (or max) every input cell that broadcasts
// TO it. One thread per output cell — no atomics needed; the
// broadcast pattern is computed from the shape arrays per thread.
//
// ABI:
// * `src` / `dst` are device pointers in `T`.
// * `input_shape` (host, length `rank_in`) — the source extents.
// * `input_stride` (host, length `rank_in`, i64) — the source
// strides; may be non-contiguous.
// * `rank_in` ≤ 8 (MAX_RANK). Caller pads `output_shape` on the
// LEFT with 1s to match `rank_in`.
// * `output_shape` (host, length `rank_in`) — the target extents;
// dst is laid out contiguously over this shape.
//
// Output pre-fill is NOT required (the per-cell kernel writes the
// identity directly if the broadcast set is empty).
/// `reduce_sum_to`, f32. Broadcast-reverse Σ: each output cell receives
/// the sum of every input cell that broadcasts to it. Phase 31.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `src` / `dst` are device pointers; `input_shape`, `input_stride`, and
/// `output_shape` are host arrays of `rank_in` (≤ 8) entries, with `dst`
/// laid out contiguously over `output_shape`. Crate-level workspace /
/// stream contract applies.
pub fn baracuda_kernels_reduce_sum_to_f32_run(
src: *const c_void, dst: *mut c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_to_f32`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `input_shape`, `input_stride`, and `output_shape` are host arrays of
/// `rank_in` entries. NOTE(review): presumably no device dereference of
/// `src` / `dst` — confirm against the C implementation.
pub fn baracuda_kernels_reduce_sum_to_f32_can_implement(
src: *const c_void, dst: *const c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
) -> i32;
/// `reduce_sum_to`, f64. Broadcast-reverse Σ: each output cell receives
/// the sum of every input cell that broadcasts to it.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `src` / `dst` are device pointers; `input_shape`, `input_stride`, and
/// `output_shape` are host arrays of `rank_in` (≤ 8) entries, with `dst`
/// laid out contiguously over `output_shape`. Crate-level workspace /
/// stream contract applies.
pub fn baracuda_kernels_reduce_sum_to_f64_run(
src: *const c_void, dst: *mut c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_to_f64`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `input_shape`, `input_stride`, and `output_shape` are host arrays of
/// `rank_in` entries. NOTE(review): presumably no device dereference of
/// `src` / `dst` — confirm against the C implementation.
pub fn baracuda_kernels_reduce_sum_to_f64_can_implement(
src: *const c_void, dst: *const c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
) -> i32;
/// `reduce_sum_to`, f16. Accumulator widens to f32 per the rest
/// of the family's convention.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `src` / `dst` are device pointers; `input_shape`, `input_stride`, and
/// `output_shape` are host arrays of `rank_in` (≤ 8) entries, with `dst`
/// laid out contiguously over `output_shape`. Crate-level workspace /
/// stream contract applies.
pub fn baracuda_kernels_reduce_sum_to_f16_run(
src: *const c_void, dst: *mut c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_to_f16`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `input_shape`, `input_stride`, and `output_shape` are host arrays of
/// `rank_in` entries. NOTE(review): presumably no device dereference of
/// `src` / `dst` — confirm against the C implementation.
pub fn baracuda_kernels_reduce_sum_to_f16_can_implement(
src: *const c_void, dst: *const c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
) -> i32;
/// `reduce_sum_to`, bf16. Accumulator widens to f32.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `src` / `dst` are device pointers; `input_shape`, `input_stride`, and
/// `output_shape` are host arrays of `rank_in` (≤ 8) entries, with `dst`
/// laid out contiguously over `output_shape`. Crate-level workspace /
/// stream contract applies.
pub fn baracuda_kernels_reduce_sum_to_bf16_run(
src: *const c_void, dst: *mut c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_sum_to_bf16`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `input_shape`, `input_stride`, and `output_shape` are host arrays of
/// `rank_in` entries. NOTE(review): presumably no device dereference of
/// `src` / `dst` — confirm against the C implementation.
pub fn baracuda_kernels_reduce_sum_to_bf16_can_implement(
src: *const c_void, dst: *const c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
) -> i32;
/// `reduce_max_to`, f32. Identity is `-FLT_MAX` when the broadcast
/// set is empty.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `src` / `dst` are device pointers; `input_shape`, `input_stride`, and
/// `output_shape` are host arrays of `rank_in` (≤ 8) entries, with `dst`
/// laid out contiguously over `output_shape`. Crate-level workspace /
/// stream contract applies.
pub fn baracuda_kernels_reduce_max_to_f32_run(
src: *const c_void, dst: *mut c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_to_f32`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `input_shape`, `input_stride`, and `output_shape` are host arrays of
/// `rank_in` entries. NOTE(review): presumably no device dereference of
/// `src` / `dst` — confirm against the C implementation.
pub fn baracuda_kernels_reduce_max_to_f32_can_implement(
src: *const c_void, dst: *const c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
) -> i32;
/// `reduce_max_to`, f64. Identity is `-DBL_MAX`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `src` / `dst` are device pointers; `input_shape`, `input_stride`, and
/// `output_shape` are host arrays of `rank_in` (≤ 8) entries, with `dst`
/// laid out contiguously over `output_shape`. Crate-level workspace /
/// stream contract applies.
pub fn baracuda_kernels_reduce_max_to_f64_run(
src: *const c_void, dst: *mut c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_to_f64`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `input_shape`, `input_stride`, and `output_shape` are host arrays of
/// `rank_in` entries. NOTE(review): presumably no device dereference of
/// `src` / `dst` — confirm against the C implementation.
pub fn baracuda_kernels_reduce_max_to_f64_can_implement(
src: *const c_void, dst: *const c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
) -> i32;
/// `reduce_max_to`, f16. Identity is `-FLT_MAX` in f32 accumulator
/// space, narrowed back to f16 on store.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `src` / `dst` are device pointers; `input_shape`, `input_stride`, and
/// `output_shape` are host arrays of `rank_in` (≤ 8) entries, with `dst`
/// laid out contiguously over `output_shape`. Crate-level workspace /
/// stream contract applies.
pub fn baracuda_kernels_reduce_max_to_f16_run(
src: *const c_void, dst: *mut c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_to_f16`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `input_shape`, `input_stride`, and `output_shape` are host arrays of
/// `rank_in` entries. NOTE(review): presumably no device dereference of
/// `src` / `dst` — confirm against the C implementation.
pub fn baracuda_kernels_reduce_max_to_f16_can_implement(
src: *const c_void, dst: *const c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
) -> i32;
/// `reduce_max_to`, bf16.
///
/// Reads `src` (`*const`) and writes `dst` (`*mut`). Returns a
/// crate-level status code (`0` = success). NOTE(review): identity
/// and accumulator width presumably follow the f16 variant — confirm.
///
/// # Safety
/// Pointers are used unchecked; the crate-level device-pointer,
/// workspace, and stream requirements apply. NOTE(review): the shape
/// and stride arrays presumably hold `rank_in` entries — confirm.
pub fn baracuda_kernels_reduce_max_to_bf16_run(
src: *const c_void, dst: *mut c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_max_to_bf16`.
///
/// Mirrors the arguments of
/// [`baracuda_kernels_reduce_max_to_bf16_run`] minus workspace and
/// stream (`dst` is `*const` here). Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_reduce_max_to_bf16_can_implement(
src: *const c_void, dst: *const c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
) -> i32;
// ====================================================================
// Phase 37 Gap 1a — broadcast-reverse MIN and PROD.
//
// Same shape and contract as `reduce_sum_to` / `reduce_max_to`. The
// empty-broadcast identity is `+FLT_MAX` (Min) and `1` (Prod). Half-
// precision storage of the Min identity narrows to `+inf` on store
// (matches the per-axis `MinReduce`'s `+INFINITY` initial value).
// ====================================================================
/// `reduce_min_to`, f32. Identity is `+FLT_MAX` when the broadcast
/// set is empty.
///
/// Reads `src` (`*const`) and writes `dst` (`*mut`). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Pointers are used unchecked; the crate-level device-pointer,
/// workspace, and stream requirements apply. NOTE(review): the shape
/// and stride arrays presumably hold `rank_in` entries — confirm.
pub fn baracuda_kernels_reduce_min_to_f32_run(
src: *const c_void, dst: *mut c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_min_to_f32`.
///
/// Mirrors the arguments of
/// [`baracuda_kernels_reduce_min_to_f32_run`] minus workspace and
/// stream (`dst` is `*const` here). Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_reduce_min_to_f32_can_implement(
src: *const c_void, dst: *const c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
) -> i32;
/// `reduce_min_to`, f64. Identity is `+DBL_MAX`.
///
/// Reads `src` (`*const`) and writes `dst` (`*mut`). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Pointers are used unchecked; the crate-level device-pointer,
/// workspace, and stream requirements apply. NOTE(review): the shape
/// and stride arrays presumably hold `rank_in` entries — confirm.
pub fn baracuda_kernels_reduce_min_to_f64_run(
src: *const c_void, dst: *mut c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_min_to_f64`.
///
/// Mirrors the arguments of
/// [`baracuda_kernels_reduce_min_to_f64_run`] minus workspace and
/// stream (`dst` is `*const` here). Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_reduce_min_to_f64_can_implement(
src: *const c_void, dst: *const c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
) -> i32;
/// `reduce_min_to`, f16. Accumulator widens to f32; identity is
/// `+FLT_MAX` in f32 accumulator space, narrowing to `+inf` on
/// store.
///
/// Reads `src` (`*const`) and writes `dst` (`*mut`). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Pointers are used unchecked; the crate-level device-pointer,
/// workspace, and stream requirements apply. NOTE(review): the shape
/// and stride arrays presumably hold `rank_in` entries — confirm.
pub fn baracuda_kernels_reduce_min_to_f16_run(
src: *const c_void, dst: *mut c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_min_to_f16`.
///
/// Mirrors the arguments of
/// [`baracuda_kernels_reduce_min_to_f16_run`] minus workspace and
/// stream (`dst` is `*const` here). Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_reduce_min_to_f16_can_implement(
src: *const c_void, dst: *const c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
) -> i32;
/// `reduce_min_to`, bf16.
///
/// Reads `src` (`*const`) and writes `dst` (`*mut`). Returns a
/// crate-level status code (`0` = success). NOTE(review): identity
/// and accumulator width presumably follow the f16 variant — confirm.
///
/// # Safety
/// Pointers are used unchecked; the crate-level device-pointer,
/// workspace, and stream requirements apply. NOTE(review): the shape
/// and stride arrays presumably hold `rank_in` entries — confirm.
pub fn baracuda_kernels_reduce_min_to_bf16_run(
src: *const c_void, dst: *mut c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_min_to_bf16`.
///
/// Mirrors the arguments of
/// [`baracuda_kernels_reduce_min_to_bf16_run`] minus workspace and
/// stream (`dst` is `*const` here). Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_reduce_min_to_bf16_can_implement(
src: *const c_void, dst: *const c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
) -> i32;
/// `reduce_prod_to`, f32. Identity is `1` (multiplicative). Half
/// dtypes accumulate in f32 then narrow on store.
///
/// Reads `src` (`*const`) and writes `dst` (`*mut`). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Pointers are used unchecked; the crate-level device-pointer,
/// workspace, and stream requirements apply. NOTE(review): the shape
/// and stride arrays presumably hold `rank_in` entries — confirm.
pub fn baracuda_kernels_reduce_prod_to_f32_run(
src: *const c_void, dst: *mut c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_to_f32`.
///
/// Mirrors the arguments of
/// [`baracuda_kernels_reduce_prod_to_f32_run`] minus workspace and
/// stream (`dst` is `*const` here). Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_reduce_prod_to_f32_can_implement(
src: *const c_void, dst: *const c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
) -> i32;
/// `reduce_prod_to`, f64.
///
/// Reads `src` (`*const`) and writes `dst` (`*mut`). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Pointers are used unchecked; the crate-level device-pointer,
/// workspace, and stream requirements apply. NOTE(review): the shape
/// and stride arrays presumably hold `rank_in` entries — confirm.
pub fn baracuda_kernels_reduce_prod_to_f64_run(
src: *const c_void, dst: *mut c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_to_f64`.
///
/// Mirrors the arguments of
/// [`baracuda_kernels_reduce_prod_to_f64_run`] minus workspace and
/// stream (`dst` is `*const` here). Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_reduce_prod_to_f64_can_implement(
src: *const c_void, dst: *const c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
) -> i32;
/// `reduce_prod_to`, f16. Cumulative product overflows fast in
/// half-precision; callers should keep values close to 1.
///
/// Reads `src` (`*const`) and writes `dst` (`*mut`). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Pointers are used unchecked; the crate-level device-pointer,
/// workspace, and stream requirements apply. NOTE(review): the shape
/// and stride arrays presumably hold `rank_in` entries — confirm.
pub fn baracuda_kernels_reduce_prod_to_f16_run(
src: *const c_void, dst: *mut c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_to_f16`.
///
/// Mirrors the arguments of
/// [`baracuda_kernels_reduce_prod_to_f16_run`] minus workspace and
/// stream (`dst` is `*const` here). Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_reduce_prod_to_f16_can_implement(
src: *const c_void, dst: *const c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
) -> i32;
/// `reduce_prod_to`, bf16.
///
/// Reads `src` (`*const`) and writes `dst` (`*mut`). Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Pointers are used unchecked; the crate-level device-pointer,
/// workspace, and stream requirements apply. NOTE(review): the shape
/// and stride arrays presumably hold `rank_in` entries — confirm.
pub fn baracuda_kernels_reduce_prod_to_bf16_run(
src: *const c_void, dst: *mut c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `reduce_prod_to_bf16`.
///
/// Mirrors the arguments of
/// [`baracuda_kernels_reduce_prod_to_bf16_run`] minus workspace and
/// stream (`dst` is `*const` here). Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_reduce_prod_to_bf16_can_implement(
src: *const c_void, dst: *const c_void,
input_shape: *const i32, input_stride: *const i64,
rank_in: i32,
output_shape: *const i32,
) -> i32;
// ---- Hardshrink (λ=0.5) ----
/// Unary elementwise `hardshrink` (λ=0.5), f32, contig.
///
/// λ is fixed at 0.5; the ABI has no threshold argument.
/// NOTE(review): conventionally `y = x` where `|x| > λ`, else `0` —
/// confirm against the kernel. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-relu trailblazer.
pub fn baracuda_kernels_unary_hardshrink_f32_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardshrink_f32`.
///
/// Takes the same `numel` / `x` / `y` as
/// [`baracuda_kernels_unary_hardshrink_f32_run`], without workspace or
/// stream. Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardshrink_f32_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// Unary elementwise `hardshrink` (λ=0.5), f32, strided.
///
/// `x` and `y` each get their own stride array (`stride_x`,
/// `stride_y`) over a shared `shape`. NOTE(review): the arrays
/// presumably hold `rank` entries — confirm. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-relu strided launcher.
pub fn baracuda_kernels_unary_hardshrink_f32_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardshrink_f32_strided`.
///
/// Takes the same layout arguments and pointers as
/// [`baracuda_kernels_unary_hardshrink_f32_strided_run`], without
/// workspace or stream. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardshrink_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardshrink` (λ=0.5), f16, contig.
///
/// λ is fixed at 0.5; the ABI has no threshold argument. Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_hardshrink_f16_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardshrink_f16`.
///
/// Takes the same `numel` / `x` / `y` as
/// [`baracuda_kernels_unary_hardshrink_f16_run`], without workspace or
/// stream. Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardshrink_f16_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// Unary elementwise `hardshrink` (λ=0.5), f16, strided.
///
/// `x` and `y` each get their own stride array (`stride_x`,
/// `stride_y`) over a shared `shape`. NOTE(review): the arrays
/// presumably hold `rank` entries — confirm. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_hardshrink_f16_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardshrink_f16_strided`.
///
/// Takes the same layout arguments and pointers as
/// [`baracuda_kernels_unary_hardshrink_f16_strided_run`], without
/// workspace or stream. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardshrink_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardshrink` (λ=0.5), bf16, contig.
///
/// λ is fixed at 0.5; the ABI has no threshold argument. Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_hardshrink_bf16_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardshrink_bf16`.
///
/// Takes the same `numel` / `x` / `y` as
/// [`baracuda_kernels_unary_hardshrink_bf16_run`], without workspace
/// or stream. Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardshrink_bf16_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// Unary elementwise `hardshrink` (λ=0.5), bf16, strided.
///
/// `x` and `y` each get their own stride array (`stride_x`,
/// `stride_y`) over a shared `shape`. NOTE(review): the arrays
/// presumably hold `rank` entries — confirm. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_hardshrink_bf16_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardshrink_bf16_strided`.
///
/// Takes the same layout arguments and pointers as
/// [`baracuda_kernels_unary_hardshrink_bf16_strided_run`], without
/// workspace or stream. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardshrink_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `hardshrink` (λ=0.5), f64, contig.
///
/// λ is fixed at 0.5; the ABI has no threshold argument. Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_hardshrink_f64_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardshrink_f64`.
///
/// Takes the same `numel` / `x` / `y` as
/// [`baracuda_kernels_unary_hardshrink_f64_run`], without workspace or
/// stream. Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardshrink_f64_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// Unary elementwise `hardshrink` (λ=0.5), f64, strided.
///
/// `x` and `y` each get their own stride array (`stride_x`,
/// `stride_y`) over a shared `shape`. NOTE(review): the arrays
/// presumably hold `rank` entries — confirm. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_hardshrink_f64_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_hardshrink_f64_strided`.
///
/// Takes the same layout arguments and pointers as
/// [`baracuda_kernels_unary_hardshrink_f64_strided_run`], without
/// workspace or stream. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_hardshrink_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// ---- Softshrink (λ=0.5) ----
/// Unary elementwise `softshrink` (λ=0.5), f32, contig.
///
/// λ is fixed at 0.5; the ABI has no threshold argument.
/// NOTE(review): conventionally `y = x − λ` for `x > λ`, `x + λ` for
/// `x < −λ`, else `0` — confirm against the kernel. Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the unary-relu trailblazer.
pub fn baracuda_kernels_unary_softshrink_f32_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softshrink_f32`.
///
/// Takes the same `numel` / `x` / `y` as
/// [`baracuda_kernels_unary_softshrink_f32_run`], without workspace or
/// stream. Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softshrink_f32_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// Unary elementwise `softshrink` (λ=0.5), f32, strided.
///
/// `x` and `y` each get their own stride array (`stride_x`,
/// `stride_y`) over a shared `shape`. NOTE(review): the arrays
/// presumably hold `rank` entries — confirm. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// Same contract as the unary-relu strided launcher.
pub fn baracuda_kernels_unary_softshrink_f32_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softshrink_f32_strided`.
///
/// Takes the same layout arguments and pointers as
/// [`baracuda_kernels_unary_softshrink_f32_strided_run`], without
/// workspace or stream. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softshrink_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `softshrink` (λ=0.5), f16, contig.
///
/// λ is fixed at 0.5; the ABI has no threshold argument. Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_softshrink_f16_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softshrink_f16`.
///
/// Takes the same `numel` / `x` / `y` as
/// [`baracuda_kernels_unary_softshrink_f16_run`], without workspace or
/// stream. Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softshrink_f16_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// Unary elementwise `softshrink` (λ=0.5), f16, strided.
///
/// `x` and `y` each get their own stride array (`stride_x`,
/// `stride_y`) over a shared `shape`. NOTE(review): the arrays
/// presumably hold `rank` entries — confirm. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_softshrink_f16_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softshrink_f16_strided`.
///
/// Takes the same layout arguments and pointers as
/// [`baracuda_kernels_unary_softshrink_f16_strided_run`], without
/// workspace or stream. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softshrink_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `softshrink` (λ=0.5), bf16, contig.
///
/// λ is fixed at 0.5; the ABI has no threshold argument. Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_softshrink_bf16_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softshrink_bf16`.
///
/// Takes the same `numel` / `x` / `y` as
/// [`baracuda_kernels_unary_softshrink_bf16_run`], without workspace
/// or stream. Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softshrink_bf16_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// Unary elementwise `softshrink` (λ=0.5), bf16, strided.
///
/// `x` and `y` each get their own stride array (`stride_x`,
/// `stride_y`) over a shared `shape`. NOTE(review): the arrays
/// presumably hold `rank` entries — confirm. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_softshrink_bf16_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softshrink_bf16_strided`.
///
/// Takes the same layout arguments and pointers as
/// [`baracuda_kernels_unary_softshrink_bf16_strided_run`], without
/// workspace or stream. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softshrink_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Unary elementwise `softshrink` (λ=0.5), f64, contig.
///
/// λ is fixed at 0.5; the ABI has no threshold argument. Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_softshrink_f64_run(
numel: i64, x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softshrink_f64`.
///
/// Takes the same `numel` / `x` / `y` as
/// [`baracuda_kernels_unary_softshrink_f64_run`], without workspace or
/// stream. Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softshrink_f64_can_implement(
numel: i64, x: *const c_void, y: *const c_void,
) -> i32;
/// Unary elementwise `softshrink` (λ=0.5), f64, strided.
///
/// `x` and `y` each get their own stride array (`stride_x`,
/// `stride_y`) over a shared `shape`. NOTE(review): the arrays
/// presumably hold `rank` entries — confirm. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_softshrink_f64_strided_run(
numel: i64, rank: i32, shape: *const i32, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_softshrink_f64_strided`.
///
/// Takes the same layout arguments and pointers as
/// [`baracuda_kernels_unary_softshrink_f64_strided_run`], without
/// workspace or stream. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Host-side checks only.
pub fn baracuda_kernels_unary_softshrink_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// =========================================================================
// Phase 3 Category C′ — gated activations (forward + backward).
//
// ABI shape: one thread per OUTPUT cell. The input is split along
// `split_dim` into two halves `(a, b)`; output `y = a · gate(b)`
// has shape `input_shape` with `input_shape[split_dim]` halved.
// `x_half_offset` is `(input_shape[split_dim] / 2) · stride_x[split_dim]`
// — the element-offset between the a-half cell and the b-half cell
// for a given output coord. `dx_half_offset` is the same for `dx`
// (which is contig over `input_shape`).
// =========================================================================
/// SwiGLU forward, f32. `y = a · b · sigmoid(b)`.
///
/// `a` / `b` are the two halves of `x` along `split_dim`;
/// `x_half_offset` is the element offset from an a-half cell to its
/// b-half cell. `output_numel` / `output_shape` describe `y`. Returns
/// a crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `float` storage.
pub fn baracuda_kernels_gated_swiglu_f32_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_swiglu_f32`.
///
/// Mirrors the arguments of [`baracuda_kernels_gated_swiglu_f32_run`]
/// minus workspace and stream (`y` is `*const` here). Returns a
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_gated_swiglu_f32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64,
stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// SwiGLU forward, f16. `y = a · b · sigmoid(b)`, with `a` / `b` the
/// two halves of `x` along `split_dim`.
///
/// Same argument list as the f32 entry point. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// NOTE(review): `x` / `y` presumably point to `__half` storage —
/// confirm. Crate-level pointer / workspace / stream rules apply.
pub fn baracuda_kernels_gated_swiglu_f16_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_swiglu_f16`.
///
/// Mirrors the arguments of [`baracuda_kernels_gated_swiglu_f16_run`]
/// minus workspace and stream (`y` is `*const` here). Returns a
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_gated_swiglu_f16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64,
stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// SwiGLU forward, bf16. `y = a · b · sigmoid(b)`, with `a` / `b` the
/// two halves of `x` along `split_dim`.
///
/// Same argument list as the f32 entry point. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// NOTE(review): `x` / `y` presumably point to `__nv_bfloat16`
/// storage — confirm. Crate-level pointer / workspace / stream rules
/// apply.
pub fn baracuda_kernels_gated_swiglu_bf16_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_swiglu_bf16`.
///
/// Mirrors the arguments of [`baracuda_kernels_gated_swiglu_bf16_run`]
/// minus workspace and stream (`y` is `*const` here). Returns a
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_gated_swiglu_bf16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64,
stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// SwiGLU forward, f64. `y = a · b · sigmoid(b)`, with `a` / `b` the
/// two halves of `x` along `split_dim`.
///
/// Same argument list as the f32 entry point. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// NOTE(review): `x` / `y` presumably point to `double` storage —
/// confirm. Crate-level pointer / workspace / stream rules apply.
pub fn baracuda_kernels_gated_swiglu_f64_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_swiglu_f64`.
///
/// Mirrors the arguments of [`baracuda_kernels_gated_swiglu_f64_run`]
/// minus workspace and stream (`y` is `*const` here). Returns a
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_gated_swiglu_f64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64,
stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// SwiGLU backward, f32. `da = dy·silu(b)`, `db = dy·a·silu'(b)`.
///
/// `x` and `dy` are read-only inputs; `dx` is the output.
/// `dx_half_offset` plays the role of `x_half_offset` for `dx`
/// (NOTE(review): presumably the offset between the `da` and `db`
/// cells — confirm). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// `x` / `dy` / `dx` point to `float` storage.
pub fn baracuda_kernels_gated_swiglu_backward_f32_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_swiglu_backward_f32`.
///
/// Mirrors the arguments of
/// [`baracuda_kernels_gated_swiglu_backward_f32_run`] minus workspace
/// and stream (`dx` is `*const` here). Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_gated_swiglu_backward_f32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// SwiGLU backward, f16. `da = dy·silu(b)`, `db = dy·a·silu'(b)`.
///
/// Same argument list as the f32 backward entry point. Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// NOTE(review): `x` / `dy` / `dx` presumably point to `__half`
/// storage — confirm. Crate-level pointer / workspace / stream rules
/// apply.
pub fn baracuda_kernels_gated_swiglu_backward_f16_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_swiglu_backward_f16`.
///
/// Mirrors the arguments of
/// [`baracuda_kernels_gated_swiglu_backward_f16_run`] minus workspace
/// and stream (`dx` is `*const` here). Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_gated_swiglu_backward_f16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// SwiGLU backward, bf16. `da = dy·silu(b)`, `db = dy·a·silu'(b)`.
///
/// Same argument list as the f32 backward entry point. Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// NOTE(review): `x` / `dy` / `dx` presumably point to
/// `__nv_bfloat16` storage — confirm. Crate-level pointer /
/// workspace / stream rules apply.
pub fn baracuda_kernels_gated_swiglu_backward_bf16_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_swiglu_backward_bf16`.
///
/// Mirrors the arguments of
/// [`baracuda_kernels_gated_swiglu_backward_bf16_run`] minus workspace
/// and stream (`dx` is `*const` here). Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_gated_swiglu_backward_bf16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// SwiGLU backward, f64. `da = dy·silu(b)`, `db = dy·a·silu'(b)`.
///
/// Same argument list as the f32 backward entry point. Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// NOTE(review): `x` / `dy` / `dx` presumably point to `double`
/// storage — confirm. Crate-level pointer / workspace / stream rules
/// apply.
pub fn baracuda_kernels_gated_swiglu_backward_f64_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_swiglu_backward_f64`.
///
/// Mirrors the arguments of
/// [`baracuda_kernels_gated_swiglu_backward_f64_run`] minus workspace
/// and stream (`dx` is `*const` here). Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_gated_swiglu_backward_f64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// GLU forward, f32. `y = a · sigmoid(b)`.
///
/// `a` / `b` are the two halves of `x` along `split_dim`;
/// `x_half_offset` is the element offset from an a-half cell to its
/// b-half cell. `output_numel` / `output_shape` describe `y`. Returns
/// a crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `float` storage.
pub fn baracuda_kernels_gated_glu_f32_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_glu_f32`.
///
/// Mirrors the arguments of [`baracuda_kernels_gated_glu_f32_run`]
/// minus workspace and stream (`y` is `*const` here). Returns a
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_gated_glu_f32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64,
stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// GLU forward, f16. `y = a · sigmoid(b)`, with `a` / `b` the two
/// halves of `x` along `split_dim`.
///
/// Same argument list as the f32 entry point. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// NOTE(review): `x` / `y` presumably point to `__half` storage —
/// confirm. Crate-level pointer / workspace / stream rules apply.
pub fn baracuda_kernels_gated_glu_f16_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_glu_f16`.
///
/// Mirrors the arguments of [`baracuda_kernels_gated_glu_f16_run`]
/// minus workspace and stream (`y` is `*const` here). Returns a
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_gated_glu_f16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64,
stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// GLU forward, bf16. `y = a · sigmoid(b)`, with `a` / `b` the two
/// halves of `x` along `split_dim`.
///
/// Same argument list as the f32 entry point. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// NOTE(review): `x` / `y` presumably point to `__nv_bfloat16`
/// storage — confirm. Crate-level pointer / workspace / stream rules
/// apply.
pub fn baracuda_kernels_gated_glu_bf16_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_glu_bf16`.
///
/// Mirrors the arguments of [`baracuda_kernels_gated_glu_bf16_run`]
/// minus workspace and stream (`y` is `*const` here). Returns a
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_gated_glu_bf16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64,
stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// GLU forward, f64. `y = a · sigmoid(b)`, with `a` / `b` the two
/// halves of `x` along `split_dim`.
///
/// Same argument list as the f32 entry point. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// NOTE(review): `x` / `y` presumably point to `double` storage —
/// confirm. Crate-level pointer / workspace / stream rules apply.
pub fn baracuda_kernels_gated_glu_f64_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_glu_f64`.
///
/// Mirrors the arguments of [`baracuda_kernels_gated_glu_f64_run`]
/// minus workspace and stream (`y` is `*const` here). Returns a
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_gated_glu_f64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64,
stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// GLU backward, f32. `da = dy·sigmoid(b)`, `db = dy·a·sigmoid(b)·(1-sigmoid(b))`.
///
/// `x` and `dy` are read-only inputs; `dx` is the output.
/// `dx_half_offset` plays the role of `x_half_offset` for `dx`
/// (NOTE(review): presumably the offset between the `da` and `db`
/// cells — confirm). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// `x` / `dy` / `dx` point to `float` storage.
pub fn baracuda_kernels_gated_glu_backward_f32_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_glu_backward_f32`.
///
/// Mirrors the arguments of
/// [`baracuda_kernels_gated_glu_backward_f32_run`] minus workspace and
/// stream (`dx` is `*const` here). Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_gated_glu_backward_f32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// GLU backward, f16. `da = dy·sigmoid(b)`,
/// `db = dy·a·sigmoid(b)·(1-sigmoid(b))`.
///
/// Same argument list as the f32 backward entry point. Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// NOTE(review): `x` / `dy` / `dx` presumably point to `__half`
/// storage — confirm. Crate-level pointer / workspace / stream rules
/// apply.
pub fn baracuda_kernels_gated_glu_backward_f16_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_glu_backward_f16`.
///
/// Mirrors the arguments of
/// [`baracuda_kernels_gated_glu_backward_f16_run`] minus workspace and
/// stream (`dx` is `*const` here). Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_gated_glu_backward_f16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// GLU backward, bf16. `da = dy·sigmoid(b)`,
/// `db = dy·a·sigmoid(b)·(1-sigmoid(b))`.
///
/// Same argument list as the f32 backward entry point. Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// NOTE(review): `x` / `dy` / `dx` presumably point to
/// `__nv_bfloat16` storage — confirm. Crate-level pointer /
/// workspace / stream rules apply.
pub fn baracuda_kernels_gated_glu_backward_bf16_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_glu_backward_bf16`.
///
/// Mirrors the arguments of
/// [`baracuda_kernels_gated_glu_backward_bf16_run`] minus workspace
/// and stream (`dx` is `*const` here). Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_gated_glu_backward_bf16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// GLU backward, f64. `da = dy·sigmoid(b)`,
/// `db = dy·a·sigmoid(b)·(1-sigmoid(b))`.
///
/// Same argument list as the f32 backward entry point. Returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// NOTE(review): `x` / `dy` / `dx` presumably point to `double`
/// storage — confirm. Crate-level pointer / workspace / stream rules
/// apply.
pub fn baracuda_kernels_gated_glu_backward_f64_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_glu_backward_f64`.
///
/// Mirrors the arguments of
/// [`baracuda_kernels_gated_glu_backward_f64_run`] minus workspace and
/// stream (`dx` is `*const` here). Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_gated_glu_backward_f64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// ReGLU forward, f32. `y = a · relu(b) = a · max(b, 0)`.
///
/// `a` / `b` are the two halves of `x` along `split_dim`;
/// `x_half_offset` is the element offset from an a-half cell to its
/// b-half cell. `output_numel` / `output_shape` describe `y`. Returns
/// a crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `float` storage.
pub fn baracuda_kernels_gated_reglu_f32_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_reglu_f32`.
///
/// Mirrors the arguments of [`baracuda_kernels_gated_reglu_f32_run`]
/// minus workspace and stream (`y` is `*const` here). Returns a
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_gated_reglu_f32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64,
stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// ReGLU forward, f16. `y = a · relu(b)`, with `a` / `b` the two
/// halves of `x` along `split_dim`.
///
/// Same argument list as the f32 entry point. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// NOTE(review): `x` / `y` presumably point to `__half` storage —
/// confirm. Crate-level pointer / workspace / stream rules apply.
pub fn baracuda_kernels_gated_reglu_f16_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_reglu_f16`.
///
/// Mirrors the arguments of [`baracuda_kernels_gated_reglu_f16_run`]
/// minus workspace and stream (`y` is `*const` here). Returns a
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_gated_reglu_f16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64,
stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// ReGLU forward, bf16. `y = a · relu(b)`, with `a` / `b` the two
/// halves of `x` along `split_dim`.
///
/// Same argument list as the f32 entry point. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// NOTE(review): `x` / `y` presumably point to `__nv_bfloat16`
/// storage — confirm. Crate-level pointer / workspace / stream rules
/// apply.
pub fn baracuda_kernels_gated_reglu_bf16_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_reglu_bf16`.
///
/// Mirrors the arguments of [`baracuda_kernels_gated_reglu_bf16_run`]
/// minus workspace and stream (`y` is `*const` here). Returns a
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_gated_reglu_bf16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64,
stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// ReGLU forward, f64. `y = a · relu(b)`, with `a` / `b` the two
/// halves of `x` along `split_dim`.
///
/// Same argument list as the f32 entry point. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// NOTE(review): `x` / `y` presumably point to `double` storage —
/// confirm. Crate-level pointer / workspace / stream rules apply.
pub fn baracuda_kernels_gated_reglu_f64_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_reglu_f64`.
///
/// Mirrors the arguments of [`baracuda_kernels_gated_reglu_f64_run`]
/// minus workspace and stream (`y` is `*const` here). Returns a
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_gated_reglu_f64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64,
stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// ReGLU backward, f32. `da = (b>0)?dy·b:0`, `db = (b>0)?dy·a:0`.
///
/// `x` and `dy` are read-only inputs; `dx` is the output.
/// `dx_half_offset` plays the role of `x_half_offset` for `dx`
/// (NOTE(review): presumably the offset between the `da` and `db`
/// cells — confirm). Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// `x` / `dy` / `dx` point to `float` storage.
pub fn baracuda_kernels_gated_reglu_backward_f32_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_reglu_backward_f32`.
///
/// Takes the shape, offset, stride and tensor-pointer arguments of the
/// matching `_run` (no workspace or stream) and returns the crate-wide
/// `i32` status code (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm.
pub fn baracuda_kernels_gated_reglu_backward_f32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// ReGLU backward, f16.
///
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `dy` / `dx` are expected to point to `__half` storage (per the
/// `_f16` suffix); pointers, workspace and `stream` must satisfy the
/// crate-level contract for raw launchers.
pub fn baracuda_kernels_gated_reglu_backward_f16_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_reglu_backward_f16`.
///
/// Takes the shape, offset, stride and tensor-pointer arguments of the
/// matching `_run` (no workspace or stream) and returns the crate-wide
/// `i32` status code (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm.
pub fn baracuda_kernels_gated_reglu_backward_f16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// ReGLU backward, bf16.
///
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `dy` / `dx` are expected to point to `__nv_bfloat16` storage
/// (per the `_bf16` suffix); pointers, workspace and `stream` must
/// satisfy the crate-level contract for raw launchers.
pub fn baracuda_kernels_gated_reglu_backward_bf16_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_reglu_backward_bf16`.
///
/// Takes the shape, offset, stride and tensor-pointer arguments of the
/// matching `_run` (no workspace or stream) and returns the crate-wide
/// `i32` status code (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm.
pub fn baracuda_kernels_gated_reglu_backward_bf16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// ReGLU backward, f64.
///
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `dy` / `dx` are expected to point to `double` storage (per the
/// `_f64` suffix); pointers, workspace and `stream` must satisfy the
/// crate-level contract for raw launchers.
pub fn baracuda_kernels_gated_reglu_backward_f64_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_reglu_backward_f64`.
///
/// Takes the shape, offset, stride and tensor-pointer arguments of the
/// matching `_run` (no workspace or stream) and returns the crate-wide
/// `i32` status code (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm.
pub fn baracuda_kernels_gated_reglu_backward_f64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// GeGLU forward, f32. `y = a · gelu(b)`, exact erf-based.
///
/// Returns the crate-wide `i32` status code (`0` = success).
/// NOTE(review): `a` / `b` presumably denote the two halves of `x`
/// selected through `split_dim` / `x_half_offset` — confirm.
///
/// # Safety
/// `x` / `y` point to `float` storage; pointers, workspace and `stream`
/// must satisfy the crate-level contract for raw launchers.
pub fn baracuda_kernels_gated_geglu_f32_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_geglu_f32`.
///
/// Takes the shape, stride and tensor-pointer arguments of the matching
/// `_run` (no workspace or stream) and returns the crate-wide `i32`
/// status code (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm.
pub fn baracuda_kernels_gated_geglu_f32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64,
stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// GeGLU forward, f16.
///
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` are expected to point to `__half` storage (per the `_f16`
/// suffix); pointers, workspace and `stream` must satisfy the
/// crate-level contract for raw launchers.
pub fn baracuda_kernels_gated_geglu_f16_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_geglu_f16`.
///
/// Takes the shape, stride and tensor-pointer arguments of the matching
/// `_run` (no workspace or stream) and returns the crate-wide `i32`
/// status code (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm.
pub fn baracuda_kernels_gated_geglu_f16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64,
stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// GeGLU forward, bf16.
///
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` are expected to point to `__nv_bfloat16` storage (per the
/// `_bf16` suffix); pointers, workspace and `stream` must satisfy the
/// crate-level contract for raw launchers.
pub fn baracuda_kernels_gated_geglu_bf16_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_geglu_bf16`.
///
/// Takes the shape, stride and tensor-pointer arguments of the matching
/// `_run` (no workspace or stream) and returns the crate-wide `i32`
/// status code (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm.
pub fn baracuda_kernels_gated_geglu_bf16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64,
stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// GeGLU forward, f64.
///
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` are expected to point to `double` storage (per the `_f64`
/// suffix); pointers, workspace and `stream` must satisfy the
/// crate-level contract for raw launchers.
pub fn baracuda_kernels_gated_geglu_f64_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_geglu_f64`.
///
/// Takes the shape, stride and tensor-pointer arguments of the matching
/// `_run` (no workspace or stream) and returns the crate-wide `i32`
/// status code (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm.
pub fn baracuda_kernels_gated_geglu_f64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64,
stride_x: *const i64, stride_y: *const i64,
x: *const c_void, y: *const c_void,
) -> i32;
/// GeGLU backward, f32. `da = dy·gelu(b)`, `db = dy·a·gelu'(b)`.
///
/// Returns the crate-wide `i32` status code (`0` = success).
/// NOTE(review): `a` / `b` presumably denote the two halves of `x`
/// selected through `split_dim` / `x_half_offset` — confirm.
///
/// # Safety
/// `x` / `dy` / `dx` point to `float` storage; pointers, workspace and
/// `stream` must satisfy the crate-level contract for raw launchers.
pub fn baracuda_kernels_gated_geglu_backward_f32_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_geglu_backward_f32`.
///
/// Takes the shape, offset, stride and tensor-pointer arguments of the
/// matching `_run` (no workspace or stream) and returns the crate-wide
/// `i32` status code (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm.
pub fn baracuda_kernels_gated_geglu_backward_f32_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// GeGLU backward, f16.
///
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `dy` / `dx` are expected to point to `__half` storage (per the
/// `_f16` suffix); pointers, workspace and `stream` must satisfy the
/// crate-level contract for raw launchers.
pub fn baracuda_kernels_gated_geglu_backward_f16_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_geglu_backward_f16`.
///
/// Takes the shape, offset, stride and tensor-pointer arguments of the
/// matching `_run` (no workspace or stream) and returns the crate-wide
/// `i32` status code (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm.
pub fn baracuda_kernels_gated_geglu_backward_f16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// GeGLU backward, bf16.
///
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `dy` / `dx` are expected to point to `__nv_bfloat16` storage
/// (per the `_bf16` suffix); pointers, workspace and `stream` must
/// satisfy the crate-level contract for raw launchers.
pub fn baracuda_kernels_gated_geglu_backward_bf16_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_geglu_backward_bf16`.
///
/// Takes the shape, offset, stride and tensor-pointer arguments of the
/// matching `_run` (no workspace or stream) and returns the crate-wide
/// `i32` status code (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm.
pub fn baracuda_kernels_gated_geglu_backward_bf16_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// GeGLU backward, f64.
///
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `dy` / `dx` are expected to point to `double` storage (per the
/// `_f64` suffix); pointers, workspace and `stream` must satisfy the
/// crate-level contract for raw launchers.
pub fn baracuda_kernels_gated_geglu_backward_f64_run(
output_numel: i64, rank: i32, output_shape: *const i32, split_dim: i32,
x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `gated_geglu_backward_f64`.
///
/// Takes the shape, offset, stride and tensor-pointer arguments of the
/// matching `_run` (no workspace or stream) and returns the crate-wide
/// `i32` status code (`0` = success).
///
/// # Safety
/// NOTE(review): presumably the same pointer-validity contract as the
/// matching `_run` — confirm.
pub fn baracuda_kernels_gated_geglu_backward_f64_can_implement(
output_numel: i64, rank: i32, output_shape: *const i32,
split_dim: i32, x_half_offset: i64, dx_half_offset: i64,
stride_x: *const i64, stride_dy: *const i64, stride_dx: *const i64,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
}
// =============================================================================
// Ternary backward family — Phase 3 backward fanout (Milestone F)
// =============================================================================
//
// 4 ops × 4 FP dtypes = 16 launchers.
//
// Unscaled (Fma, Clamp) — 7-pointer ABI: dy, a, b, c, da, db, dc.
// Scaled (Addcmul, Addcdiv) — same 7 pointers + `float scale` between
// `dc` and the workspace pointer, mirroring the FW scaled-ternary ABI.
//
// All four input tensors (`dy` plus the saved `a`, `b`, `c`) are read for
// every cell, regardless of whether the op's gradient references them —
// see the .cu file comments for why (uniform ABI across the family; one
// extra coalesced load is cheap).
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// Conventions shared by every launcher in this block: `dy`, `a`, `b`,
// `c` are read-only inputs, `da`, `db`, `dc` are the gradient outputs,
// and the return value is the crate-wide `i32` status code (`0` =
// success; see the crate-level "Status codes" section). Pointers,
// workspace and `stream` follow the crate-level raw-launcher contract.
// ---- Fma backward (unscaled) ----
/// Fma backward, f32. Writes `da = dy*b`, `db = dy*a`, `dc = dy`.
///
/// # Safety
/// Tensor pointers are expected to reference `float` storage (per the
/// `_f32` suffix).
pub fn baracuda_kernels_ternary_fma_backward_f32_run(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *mut c_void, db: *mut c_void, dc: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `ternary_fma_backward_f32`. Takes the tensor
/// arguments of the matching `_run` (no workspace or stream).
pub fn baracuda_kernels_ternary_fma_backward_f32_can_implement(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *const c_void, db: *const c_void, dc: *const c_void,
) -> i32;
/// Fma backward, f16.
///
/// # Safety
/// Tensor pointers are expected to reference `__half` storage (per the
/// `_f16` suffix).
pub fn baracuda_kernels_ternary_fma_backward_f16_run(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *mut c_void, db: *mut c_void, dc: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `ternary_fma_backward_f16`. Takes the tensor
/// arguments of the matching `_run` (no workspace or stream).
pub fn baracuda_kernels_ternary_fma_backward_f16_can_implement(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *const c_void, db: *const c_void, dc: *const c_void,
) -> i32;
/// Fma backward, bf16.
///
/// # Safety
/// Tensor pointers are expected to reference `__nv_bfloat16` storage
/// (per the `_bf16` suffix).
pub fn baracuda_kernels_ternary_fma_backward_bf16_run(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *mut c_void, db: *mut c_void, dc: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `ternary_fma_backward_bf16`. Takes the tensor
/// arguments of the matching `_run` (no workspace or stream).
pub fn baracuda_kernels_ternary_fma_backward_bf16_can_implement(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *const c_void, db: *const c_void, dc: *const c_void,
) -> i32;
/// Fma backward, f64.
///
/// # Safety
/// Tensor pointers are expected to reference `double` storage (per the
/// `_f64` suffix).
pub fn baracuda_kernels_ternary_fma_backward_f64_run(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *mut c_void, db: *mut c_void, dc: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `ternary_fma_backward_f64`. Takes the tensor
/// arguments of the matching `_run` (no workspace or stream).
pub fn baracuda_kernels_ternary_fma_backward_f64_can_implement(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *const c_void, db: *const c_void, dc: *const c_void,
) -> i32;
// ---- Clamp backward (unscaled, mask × dy) ----
/// Clamp backward, f32. Writes mask × dy per axis (a/b/c).
///
/// # Safety
/// Tensor pointers are expected to reference `float` storage (per the
/// `_f32` suffix).
pub fn baracuda_kernels_ternary_clamp_backward_f32_run(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *mut c_void, db: *mut c_void, dc: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `ternary_clamp_backward_f32`. Takes the tensor
/// arguments of the matching `_run` (no workspace or stream).
pub fn baracuda_kernels_ternary_clamp_backward_f32_can_implement(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *const c_void, db: *const c_void, dc: *const c_void,
) -> i32;
/// Clamp backward, f16.
///
/// # Safety
/// Tensor pointers are expected to reference `__half` storage (per the
/// `_f16` suffix).
pub fn baracuda_kernels_ternary_clamp_backward_f16_run(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *mut c_void, db: *mut c_void, dc: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `ternary_clamp_backward_f16`. Takes the tensor
/// arguments of the matching `_run` (no workspace or stream).
pub fn baracuda_kernels_ternary_clamp_backward_f16_can_implement(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *const c_void, db: *const c_void, dc: *const c_void,
) -> i32;
/// Clamp backward, bf16.
///
/// # Safety
/// Tensor pointers are expected to reference `__nv_bfloat16` storage
/// (per the `_bf16` suffix).
pub fn baracuda_kernels_ternary_clamp_backward_bf16_run(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *mut c_void, db: *mut c_void, dc: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `ternary_clamp_backward_bf16`. Takes the tensor
/// arguments of the matching `_run` (no workspace or stream).
pub fn baracuda_kernels_ternary_clamp_backward_bf16_can_implement(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *const c_void, db: *const c_void, dc: *const c_void,
) -> i32;
/// Clamp backward, f64.
///
/// # Safety
/// Tensor pointers are expected to reference `double` storage (per the
/// `_f64` suffix).
pub fn baracuda_kernels_ternary_clamp_backward_f64_run(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *mut c_void, db: *mut c_void, dc: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `ternary_clamp_backward_f64`. Takes the tensor
/// arguments of the matching `_run` (no workspace or stream).
pub fn baracuda_kernels_ternary_clamp_backward_f64_can_implement(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *const c_void, db: *const c_void, dc: *const c_void,
) -> i32;
// ---- Addcmul backward (scaled) ----
/// Addcmul backward, f32. Uses the by-value `scale` argument.
/// Writes `da = dy`, `db = dy*scale*c`, `dc = dy*scale*b`.
///
/// # Safety
/// Tensor pointers are expected to reference `float` storage (per the
/// `_f32` suffix).
pub fn baracuda_kernels_ternary_addcmul_backward_f32_run(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *mut c_void, db: *mut c_void, dc: *mut c_void,
scale: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `ternary_addcmul_backward_f32`. Takes the tensor
/// arguments and `scale` of the matching `_run` (no workspace or stream).
pub fn baracuda_kernels_ternary_addcmul_backward_f32_can_implement(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *const c_void, db: *const c_void, dc: *const c_void,
scale: f32,
) -> i32;
/// Addcmul backward, f16. `scale` is passed as `f32`.
///
/// # Safety
/// Tensor pointers are expected to reference `__half` storage (per the
/// `_f16` suffix).
pub fn baracuda_kernels_ternary_addcmul_backward_f16_run(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *mut c_void, db: *mut c_void, dc: *mut c_void,
scale: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `ternary_addcmul_backward_f16`. Takes the tensor
/// arguments and `scale` of the matching `_run` (no workspace or stream).
pub fn baracuda_kernels_ternary_addcmul_backward_f16_can_implement(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *const c_void, db: *const c_void, dc: *const c_void,
scale: f32,
) -> i32;
/// Addcmul backward, bf16. `scale` is passed as `f32`.
///
/// # Safety
/// Tensor pointers are expected to reference `__nv_bfloat16` storage
/// (per the `_bf16` suffix).
pub fn baracuda_kernels_ternary_addcmul_backward_bf16_run(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *mut c_void, db: *mut c_void, dc: *mut c_void,
scale: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `ternary_addcmul_backward_bf16`. Takes the tensor
/// arguments and `scale` of the matching `_run` (no workspace or stream).
pub fn baracuda_kernels_ternary_addcmul_backward_bf16_can_implement(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *const c_void, db: *const c_void, dc: *const c_void,
scale: f32,
) -> i32;
/// Addcmul backward, f64. Note that `scale` stays `f32` on this ABI even
/// though the tensors are f64.
///
/// # Safety
/// Tensor pointers are expected to reference `double` storage (per the
/// `_f64` suffix).
pub fn baracuda_kernels_ternary_addcmul_backward_f64_run(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *mut c_void, db: *mut c_void, dc: *mut c_void,
scale: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `ternary_addcmul_backward_f64`. Takes the tensor
/// arguments and `scale` of the matching `_run` (no workspace or stream).
pub fn baracuda_kernels_ternary_addcmul_backward_f64_can_implement(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *const c_void, db: *const c_void, dc: *const c_void,
scale: f32,
) -> i32;
// ---- Addcdiv backward (scaled) ----
/// Addcdiv backward, f32. Uses the by-value `scale` argument.
/// Writes `da = dy`, `db = dy*scale/c`, `dc = -dy*scale*b/c²`.
///
/// # Safety
/// Tensor pointers are expected to reference `float` storage (per the
/// `_f32` suffix).
pub fn baracuda_kernels_ternary_addcdiv_backward_f32_run(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *mut c_void, db: *mut c_void, dc: *mut c_void,
scale: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `ternary_addcdiv_backward_f32`. Takes the tensor
/// arguments and `scale` of the matching `_run` (no workspace or stream).
pub fn baracuda_kernels_ternary_addcdiv_backward_f32_can_implement(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *const c_void, db: *const c_void, dc: *const c_void,
scale: f32,
) -> i32;
/// Addcdiv backward, f16. `scale` is passed as `f32`.
///
/// # Safety
/// Tensor pointers are expected to reference `__half` storage (per the
/// `_f16` suffix).
pub fn baracuda_kernels_ternary_addcdiv_backward_f16_run(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *mut c_void, db: *mut c_void, dc: *mut c_void,
scale: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `ternary_addcdiv_backward_f16`. Takes the tensor
/// arguments and `scale` of the matching `_run` (no workspace or stream).
pub fn baracuda_kernels_ternary_addcdiv_backward_f16_can_implement(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *const c_void, db: *const c_void, dc: *const c_void,
scale: f32,
) -> i32;
/// Addcdiv backward, bf16. `scale` is passed as `f32`.
///
/// # Safety
/// Tensor pointers are expected to reference `__nv_bfloat16` storage
/// (per the `_bf16` suffix).
pub fn baracuda_kernels_ternary_addcdiv_backward_bf16_run(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *mut c_void, db: *mut c_void, dc: *mut c_void,
scale: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `ternary_addcdiv_backward_bf16`. Takes the tensor
/// arguments and `scale` of the matching `_run` (no workspace or stream).
pub fn baracuda_kernels_ternary_addcdiv_backward_bf16_can_implement(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *const c_void, db: *const c_void, dc: *const c_void,
scale: f32,
) -> i32;
/// Addcdiv backward, f64. Note that `scale` stays `f32` on this ABI even
/// though the tensors are f64.
///
/// # Safety
/// Tensor pointers are expected to reference `double` storage (per the
/// `_f64` suffix).
pub fn baracuda_kernels_ternary_addcdiv_backward_f64_run(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *mut c_void, db: *mut c_void, dc: *mut c_void,
scale: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch check for `ternary_addcdiv_backward_f64`. Takes the tensor
/// arguments and `scale` of the matching `_run` (no workspace or stream).
pub fn baracuda_kernels_ternary_addcdiv_backward_f64_can_implement(
numel: i64,
dy: *const c_void, a: *const c_void, b: *const c_void, c: *const c_void,
da: *const c_void, db: *const c_void, dc: *const c_void,
scale: f32,
) -> i32;
}
// ----------------------------------------------------------------------------
// Parameterized unary / binary plan families — Phase 3 deferred ops.
//
// New ABI shape vs the plain unary / binary launchers: f32 scalar
// parameters threaded by value through the launcher signature.
// Unary param FW : `(numel, x, y, p0, p1, ws, ws_bytes, stream)`
// Unary param BW : `(numel, dy, x, dx, p0, p1, ws, ws_bytes, stream)`
// Binary param FW: `(numel, a, b, y, p, ws, ws_bytes, stream)`
// Binary param BW: `(numel, dy, da, db, p, ws, ws_bytes, stream)`
//
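// Today's wired ops:
// Threshold (2 params: t = p0, v = p1) — FW + BW × {f32, f16, bf16, f64}.
// PowI (1 param : n = p0, p1 unused) — FW + BW × {f32, f16, bf16, f64}.
// Lerp (1 param : weight = p) — FW + BW × {f32, f16, bf16, f64}.
//
// Contig-only for the trailblazer ops; PowI additionally has
// `*_strided_*` FW launchers (Phase 14.2).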
// ----------------------------------------------------------------------------
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---- Threshold FW (params: p0 = t, p1 = v) ----
/// Unary elementwise `threshold(x; t, v) = (x > t) ? x : v`, f32, contig.
///
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the plain unary launchers.
/// `p0` carries the threshold `t`; `p1` carries the replacement value `v`.
pub fn baracuda_kernels_unary_threshold_f32_run(
numel: i64, x: *const c_void, y: *mut c_void,
p0: f32, p1: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_threshold_f32`. Host-side only.
///
/// Takes the tensor pointers and `p0` / `p1` of the matching `_run` (no
/// workspace or stream) and returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_threshold_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// `threshold` FW, f16. `p0` = threshold `t`, `p1` = replacement `v`,
/// both passed as f32. Returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__half` storage.
pub fn baracuda_kernels_unary_threshold_f16_run(
numel: i64, x: *const c_void, y: *mut c_void,
p0: f32, p1: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_threshold_f16`. Host-side only.
///
/// Takes the tensor pointers and `p0` / `p1` of the matching `_run` (no
/// workspace or stream) and returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_threshold_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// `threshold` FW, bf16. `p0` = threshold `t`, `p1` = replacement `v`,
/// both passed as f32. Returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_threshold_bf16_run(
numel: i64, x: *const c_void, y: *mut c_void,
p0: f32, p1: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_threshold_bf16`. Host-side only.
///
/// Takes the tensor pointers and `p0` / `p1` of the matching `_run` (no
/// workspace or stream) and returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_threshold_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// `threshold` FW, f64. The f32 params widen to f64 losslessly.
/// `p0` = threshold `t`, `p1` = replacement `v`. Returns the crate-wide
/// `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `double` storage.
pub fn baracuda_kernels_unary_threshold_f64_run(
numel: i64, x: *const c_void, y: *mut c_void,
p0: f32, p1: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_threshold_f64`. Host-side only.
///
/// Takes the tensor pointers and `p0` / `p1` of the matching `_run` (no
/// workspace or stream) and returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_threshold_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
p0: f32,
p1: f32,
) -> i32;
// ---- Threshold BW (saved-x; params: p0 = t, p1 = v unused) ----
/// `threshold` backward: `dx = (x > t) ? dy : 0`, f32. Saved-x.
///
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// `dy`, `x`, `dx` device pointers; `p0` carries the threshold `t`;
/// `p1` ignored by the kernel (kept on the ABI for shape parity with FW).
pub fn baracuda_kernels_unary_threshold_backward_f32_run(
numel: i64, dy: *const c_void, x: *const c_void, dx: *mut c_void,
p0: f32, p1: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_threshold_backward_f32`.
/// Host-side validation; no kernel launch. Returns the crate-wide `i32`
/// status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_threshold_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
x: *const c_void,
dx: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// `threshold` BW, f16. `p0` = threshold `t`; `p1` is unused on the BW
/// path. Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All tensor pointers reference `__half` storage.
pub fn baracuda_kernels_unary_threshold_backward_f16_run(
numel: i64, dy: *const c_void, x: *const c_void, dx: *mut c_void,
p0: f32, p1: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_threshold_backward_f16`.
/// Host-side validation; no kernel launch. Returns the crate-wide `i32`
/// status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_threshold_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
x: *const c_void,
dx: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// `threshold` BW, bf16. `p0` = threshold `t`; `p1` is unused on the BW
/// path. Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All tensor pointers reference `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_threshold_backward_bf16_run(
numel: i64, dy: *const c_void, x: *const c_void, dx: *mut c_void,
p0: f32, p1: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_threshold_backward_bf16`.
/// Host-side validation; no kernel launch. Returns the crate-wide `i32`
/// status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_threshold_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
x: *const c_void,
dx: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// `threshold` BW, f64. `p0` = threshold `t`; `p1` is unused on the BW
/// path. Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All tensor pointers reference `double` storage.
pub fn baracuda_kernels_unary_threshold_backward_f64_run(
numel: i64, dy: *const c_void, x: *const c_void, dx: *mut c_void,
p0: f32, p1: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_threshold_backward_f64`.
/// Host-side validation; no kernel launch. Returns the crate-wide `i32`
/// status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_threshold_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
x: *const c_void,
dx: *const c_void,
p0: f32,
p1: f32,
) -> i32;
// ---- PowI FW (param: p0 = n as float, p1 unused) ----
//
// Integer-exponent power: `y = x^n` via power-by-squaring. The
// exponent `n` is shipped as `p0` cast to `int` at the kernel
// boundary; reasonable |n| values (≤ 2^24) round-trip through f32
// exactly. `p1` is ignored — kept for ABI parity with the rest of
// the `unary_param_*` family.
/// Unary elementwise `powi(x; n) = x^n` (integer exponent), f32, contig.
///
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// Same device-pointer / stream contract as the plain unary launchers.
/// `p0` carries the integer exponent `n` converted to f32 by value (it
/// is cast back to `int` at the kernel boundary — not a bit
/// reinterpretation); `p1` ignored.
pub fn baracuda_kernels_unary_powi_f32_run(
numel: i64, x: *const c_void, y: *mut c_void,
p0: f32, p1: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_powi_f32`. Host-side only.
///
/// Takes the tensor pointers and `p0` / `p1` of the matching `_run` (no
/// workspace or stream) and returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_powi_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// `powi` FW, f16. `p0` = exponent `n` as an f32 value; `p1` ignored.
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__half` storage; product chain runs in f32.
pub fn baracuda_kernels_unary_powi_f16_run(
numel: i64, x: *const c_void, y: *mut c_void,
p0: f32, p1: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_powi_f16`. Host-side only.
///
/// Takes the tensor pointers and `p0` / `p1` of the matching `_run` (no
/// workspace or stream) and returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_powi_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// `powi` FW, bf16. `p0` = exponent `n` as an f32 value; `p1` ignored.
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `__nv_bfloat16` storage; product chain runs in f32.
pub fn baracuda_kernels_unary_powi_bf16_run(
numel: i64, x: *const c_void, y: *mut c_void,
p0: f32, p1: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_powi_bf16`. Host-side only.
///
/// Takes the tensor pointers and `p0` / `p1` of the matching `_run` (no
/// workspace or stream) and returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_powi_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// `powi` FW, f64. `p0` = exponent `n` as an f32 value; `p1` ignored.
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// `x` / `y` point to `double` storage; product chain runs in f64.
pub fn baracuda_kernels_unary_powi_f64_run(
numel: i64, x: *const c_void, y: *mut c_void,
p0: f32, p1: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_powi_f64`. Host-side only.
///
/// Takes the tensor pointers and `p0` / `p1` of the matching `_run` (no
/// workspace or stream) and returns the crate-wide `i32` status code
/// (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_powi_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
p0: f32,
p1: f32,
) -> i32;
// ---- PowI BW (saved-x; param: p0 = n as float, p1 unused) ----
//
// `dx = n * x^(n-1) * dy`. Special-cased internally for `n == 0`
// (writes 0) and `n == 1` (writes `dy`).
/// `powi` backward: `dx = n · x^(n-1) · dy`, f32. Saved-x.
///
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// `dy`, `x`, `dx` device pointers; `p0` is the integer exponent
/// converted to f32 by value (not a bit reinterpretation), `p1` ignored.
pub fn baracuda_kernels_unary_powi_backward_f32_run(
numel: i64, dy: *const c_void, x: *const c_void, dx: *mut c_void,
p0: f32, p1: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_powi_backward_f32`.
/// Host-side validation; no kernel launch. Returns the crate-wide `i32`
/// status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_powi_backward_f32_can_implement(
numel: i64,
dy: *const c_void,
x: *const c_void,
dx: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// `powi` BW, f16. `p0` = exponent `n` as an f32 value; `p1` ignored.
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All tensor pointers reference `__half` storage.
pub fn baracuda_kernels_unary_powi_backward_f16_run(
numel: i64, dy: *const c_void, x: *const c_void, dx: *mut c_void,
p0: f32, p1: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_powi_backward_f16`.
/// Host-side validation; no kernel launch. Returns the crate-wide `i32`
/// status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_powi_backward_f16_can_implement(
numel: i64,
dy: *const c_void,
x: *const c_void,
dx: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// `powi` BW, bf16. `p0` = exponent `n` as an f32 value; `p1` ignored.
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All tensor pointers reference `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_powi_backward_bf16_run(
numel: i64, dy: *const c_void, x: *const c_void, dx: *mut c_void,
p0: f32, p1: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_powi_backward_bf16`.
/// Host-side validation; no kernel launch. Returns the crate-wide `i32`
/// status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_powi_backward_bf16_can_implement(
numel: i64,
dy: *const c_void,
x: *const c_void,
dx: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// `powi` BW, f64. `p0` = exponent `n` as an f32 value; `p1` ignored.
/// Returns the crate-wide `i32` status code (`0` = success).
///
/// # Safety
/// All tensor pointers reference `double` storage.
pub fn baracuda_kernels_unary_powi_backward_f64_run(
numel: i64, dy: *const c_void, x: *const c_void, dx: *mut c_void,
p0: f32, p1: f32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_powi_backward_f64`.
/// Host-side validation; no kernel launch. Returns the crate-wide `i32`
/// status code (`0` = success).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur.
pub fn baracuda_kernels_unary_powi_backward_f64_can_implement(
numel: i64,
dy: *const c_void,
x: *const c_void,
dx: *const c_void,
p0: f32,
p1: f32,
) -> i32;
// ---- PowI strided FW (param: p0 = n as float, p1 unused) — Phase 14.2 ----
//
// Strided sibling of the contig PowI FW launchers above. One thread
// per output cell: each thread decomposes its linear index into a
// multi-coord against `shape`, then dots with `stride_x` / `stride_y`
// (signed i64) to derive the operand offsets. Same `p0` / `p1` ABI
// as the contig launchers — `p0` carries `n as f32`, `p1` ignored.
// `shape`, `stride_x`, `stride_y` are host pointers (read at launch).
/// Strided `powi` forward launcher for f32 tensors.
///
/// `shape`, `stride_x`, and `stride_y` describe the input and output
/// views. `p0` carries the integer exponent as `f32`; `p1` is ignored.
/// Returns a crate status code (`0` = success).
///
/// # Safety
/// Device pointers and `stream` follow the contiguous launcher's contract.
/// `shape`, `stride_x`, and `stride_y` must be host arrays of `rank`
/// entries that stay valid for the duration of the call; they are read at
/// launch time, not during kernel execution.
pub fn baracuda_kernels_unary_powi_f32_strided_run(
    numel: i64,
    rank: i32,
    shape: *const i32,
    stride_x: *const i64,
    stride_y: *const i64,
    x: *const c_void,
    y: *mut c_void,
    p0: f32,
    p1: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_powi_f32_strided`. Host-side only.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_powi_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// Strided `powi` forward launcher for f16 tensors.
///
/// Returns a crate status code (`0` = success).
///
/// # Safety
/// Identical to `baracuda_kernels_unary_powi_f32_strided_run`, except that
/// `x` and `y` must point to `__half` storage.
pub fn baracuda_kernels_unary_powi_f16_strided_run(
    numel: i64,
    rank: i32,
    shape: *const i32,
    stride_x: *const i64,
    stride_y: *const i64,
    x: *const c_void,
    y: *mut c_void,
    p0: f32,
    p1: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_powi_f16_strided`. Host-side only.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_powi_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// Strided `powi` forward launcher for bf16 tensors.
///
/// Returns a crate status code (`0` = success).
///
/// # Safety
/// Identical to `baracuda_kernels_unary_powi_f32_strided_run`, except that
/// `x` and `y` must point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_powi_bf16_strided_run(
    numel: i64,
    rank: i32,
    shape: *const i32,
    stride_x: *const i64,
    stride_y: *const i64,
    x: *const c_void,
    y: *mut c_void,
    p0: f32,
    p1: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_powi_bf16_strided`. Host-side only.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_powi_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// Strided `powi` forward launcher for f64 tensors.
///
/// Returns a crate status code (`0` = success).
///
/// # Safety
/// Identical to `baracuda_kernels_unary_powi_f32_strided_run`, except that
/// `x` and `y` must point to `double` storage.
pub fn baracuda_kernels_unary_powi_f64_strided_run(
    numel: i64,
    rank: i32,
    shape: *const i32,
    stride_x: *const i64,
    stride_y: *const i64,
    x: *const c_void,
    y: *mut c_void,
    p0: f32,
    p1: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_unary_powi_f64_strided`. Host-side only.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_unary_powi_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
p0: f32,
p1: f32,
) -> i32;
// ---- PowI strided BW (saved-x; param: p0 = n as float, p1 unused) — Phase 14.2 ----
//
// Strided sibling of the contig PowI BW launchers. Carries three
// independent stride arrays — `stride_x`, `stride_dy`, `stride_dx`
// — so each of the three operands may be a different view.
/// Strided `powi` backward launcher for f32 tensors.
///
/// `x`, `dy`, and `dx` each have their own stride array (`stride_x`,
/// `stride_dy`, `stride_dx`) over the shared `shape`. `p0` is the integer
/// exponent; `p1` is ignored. Returns a crate status code (`0` = success).
///
/// # Safety
/// Device pointers and `stream` follow the contiguous BW launcher's
/// contract. `shape`, `stride_x`, `stride_dy`, and `stride_dx` must be host
/// arrays of `rank` entries.
pub fn baracuda_kernels_unary_powi_backward_f32_strided_run(
    numel: i64,
    rank: i32,
    shape: *const i32,
    stride_x: *const i64,
    stride_dy: *const i64,
    stride_dx: *const i64,
    x: *const c_void,
    dy: *const c_void,
    dx: *mut c_void,
    p0: f32,
    p1: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_powi_backward_f32_strided`.
/// Host-side validation; no kernel launch.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`dx` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur. `shape`, `stride_x`, `stride_dy`,
/// `stride_dx` are host-side arrays of length `rank`.
pub fn baracuda_kernels_unary_powi_backward_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_dy: *const i64,
stride_dx: *const i64,
x: *const c_void,
dy: *const c_void,
dx: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// Strided `powi` backward launcher for f16 tensors.
///
/// Returns a crate status code (`0` = success).
///
/// # Safety
/// Identical to `baracuda_kernels_unary_powi_backward_f32_strided_run`,
/// except that every tensor pointer must reference `__half` storage.
pub fn baracuda_kernels_unary_powi_backward_f16_strided_run(
    numel: i64,
    rank: i32,
    shape: *const i32,
    stride_x: *const i64,
    stride_dy: *const i64,
    stride_dx: *const i64,
    x: *const c_void,
    dy: *const c_void,
    dx: *mut c_void,
    p0: f32,
    p1: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_powi_backward_f16_strided`.
/// Host-side validation; no kernel launch.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`dx` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur. `shape`, `stride_x`, `stride_dy`,
/// `stride_dx` are host-side arrays of length `rank`.
pub fn baracuda_kernels_unary_powi_backward_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_dy: *const i64,
stride_dx: *const i64,
x: *const c_void,
dy: *const c_void,
dx: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// Strided `powi` backward launcher for bf16 tensors.
///
/// Returns a crate status code (`0` = success).
///
/// # Safety
/// Identical to `baracuda_kernels_unary_powi_backward_f32_strided_run`,
/// except that every tensor pointer must reference `__nv_bfloat16` storage.
pub fn baracuda_kernels_unary_powi_backward_bf16_strided_run(
    numel: i64,
    rank: i32,
    shape: *const i32,
    stride_x: *const i64,
    stride_dy: *const i64,
    stride_dx: *const i64,
    x: *const c_void,
    dy: *const c_void,
    dx: *mut c_void,
    p0: f32,
    p1: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_powi_backward_bf16_strided`.
/// Host-side validation; no kernel launch.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`dx` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur. `shape`, `stride_x`, `stride_dy`,
/// `stride_dx` are host-side arrays of length `rank`.
pub fn baracuda_kernels_unary_powi_backward_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_dy: *const i64,
stride_dx: *const i64,
x: *const c_void,
dy: *const c_void,
dx: *const c_void,
p0: f32,
p1: f32,
) -> i32;
/// Strided `powi` backward launcher for f64 tensors.
///
/// Returns a crate status code (`0` = success).
///
/// # Safety
/// Identical to `baracuda_kernels_unary_powi_backward_f32_strided_run`,
/// except that every tensor pointer must reference `double` storage.
pub fn baracuda_kernels_unary_powi_backward_f64_strided_run(
    numel: i64,
    rank: i32,
    shape: *const i32,
    stride_x: *const i64,
    stride_dy: *const i64,
    stride_dx: *const i64,
    x: *const c_void,
    dy: *const c_void,
    dx: *mut c_void,
    p0: f32,
    p1: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32;
/// Pre-launch implementability check for `unary_powi_backward_f64_strided`.
/// Host-side validation; no kernel launch.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`dx` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`; no
/// device dereferences occur. `shape`, `stride_x`, `stride_dy`,
/// `stride_dx` are host-side arrays of length `rank`.
pub fn baracuda_kernels_unary_powi_backward_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_dy: *const i64,
stride_dx: *const i64,
x: *const c_void,
dy: *const c_void,
dx: *const c_void,
p0: f32,
p1: f32,
) -> i32;
// ---- Lerp FW (param: weight) ----
/// Contiguous f32 launcher for elementwise
/// `lerp(a, b; weight) = a + weight·(b - a)`.
///
/// `p` is the scalar `weight`, broadcast to every element. Returns a crate
/// status code (`0` = success; see the crate-level "Status codes" table).
///
/// # Safety
/// Device pointers and `stream` follow the same contract as the plain
/// binary launchers.
pub fn baracuda_kernels_binary_lerp_f32_run(
    numel: i64,
    a: *const c_void,
    b: *const c_void,
    y: *mut c_void,
    p: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_binary_lerp_f32_run`.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// NOTE(review): presumably host-side only, like the other
/// `_can_implement` entry points in this file — confirm.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_binary_lerp_f32_can_implement(
    numel: i64,
    a: *const c_void,
    b: *const c_void,
    y: *const c_void,
    p: f32,
) -> i32;
/// Contiguous f16 launcher for elementwise `lerp`.
///
/// `p` is the scalar `weight`. Returns a crate status code (`0` = success).
///
/// # Safety
/// `a`, `b`, and `y` must point to `__half` storage.
pub fn baracuda_kernels_binary_lerp_f16_run(
    numel: i64,
    a: *const c_void,
    b: *const c_void,
    y: *mut c_void,
    p: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_binary_lerp_f16_run`.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// NOTE(review): presumably host-side only, like the other
/// `_can_implement` entry points in this file — confirm.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_binary_lerp_f16_can_implement(
    numel: i64,
    a: *const c_void,
    b: *const c_void,
    y: *const c_void,
    p: f32,
) -> i32;
/// Contiguous bf16 launcher for elementwise `lerp`.
///
/// `p` is the scalar `weight`. Returns a crate status code (`0` = success).
///
/// # Safety
/// `a`, `b`, and `y` must point to `__nv_bfloat16` storage.
pub fn baracuda_kernels_binary_lerp_bf16_run(
    numel: i64,
    a: *const c_void,
    b: *const c_void,
    y: *mut c_void,
    p: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_binary_lerp_bf16_run`.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// NOTE(review): presumably host-side only, like the other
/// `_can_implement` entry points in this file — confirm.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_binary_lerp_bf16_can_implement(
    numel: i64,
    a: *const c_void,
    b: *const c_void,
    y: *const c_void,
    p: f32,
) -> i32;
/// Contiguous f64 launcher for elementwise `lerp`.
///
/// `p` is the scalar `weight`; it stays `f32` in the ABI and widens to f64
/// losslessly. Returns a crate status code (`0` = success).
///
/// # Safety
/// `a`, `b`, and `y` must point to `double` storage.
pub fn baracuda_kernels_binary_lerp_f64_run(
    numel: i64,
    a: *const c_void,
    b: *const c_void,
    y: *mut c_void,
    p: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_binary_lerp_f64_run`.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// NOTE(review): presumably host-side only, like the other
/// `_can_implement` entry points in this file — confirm.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_binary_lerp_f64_can_implement(
    numel: i64,
    a: *const c_void,
    b: *const c_void,
    y: *const c_void,
    p: f32,
) -> i32;
// ---- Lerp BW (no saves; param: weight) ----
/// f32 launcher for `lerp` backward: `da = (1 - weight)·dy` and
/// `db = weight·dy`. No saved forward tensors are needed.
///
/// `p` is the scalar `weight`. Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// `dy`, `da`, and `db` must be device pointers.
pub fn baracuda_kernels_binary_lerp_backward_f32_run(
    numel: i64,
    dy: *const c_void,
    da: *mut c_void,
    db: *mut c_void,
    p: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_binary_lerp_backward_f32_run`.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`da` / `db` are `*const` here). Returns a crate status code
/// (`0` = success; see the crate-level "Status codes" table).
///
/// NOTE(review): presumably host-side only, like the other
/// `_can_implement` entry points in this file — confirm.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_binary_lerp_backward_f32_can_implement(
    numel: i64,
    dy: *const c_void,
    da: *const c_void,
    db: *const c_void,
    p: f32,
) -> i32;
/// f16 launcher for `lerp` backward.
///
/// `p` is the scalar `weight`. Returns a crate status code (`0` = success).
///
/// # Safety
/// `dy`, `da`, and `db` must all reference `__half` storage.
pub fn baracuda_kernels_binary_lerp_backward_f16_run(
    numel: i64,
    dy: *const c_void,
    da: *mut c_void,
    db: *mut c_void,
    p: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_binary_lerp_backward_f16_run`.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`da` / `db` are `*const` here). Returns a crate status code
/// (`0` = success; see the crate-level "Status codes" table).
///
/// NOTE(review): presumably host-side only, like the other
/// `_can_implement` entry points in this file — confirm.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_binary_lerp_backward_f16_can_implement(
    numel: i64,
    dy: *const c_void,
    da: *const c_void,
    db: *const c_void,
    p: f32,
) -> i32;
/// bf16 launcher for `lerp` backward.
///
/// `p` is the scalar `weight`. Returns a crate status code (`0` = success).
///
/// # Safety
/// `dy`, `da`, and `db` must all reference `__nv_bfloat16` storage.
pub fn baracuda_kernels_binary_lerp_backward_bf16_run(
    numel: i64,
    dy: *const c_void,
    da: *mut c_void,
    db: *mut c_void,
    p: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_binary_lerp_backward_bf16_run`.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`da` / `db` are `*const` here). Returns a crate status code
/// (`0` = success; see the crate-level "Status codes" table).
///
/// NOTE(review): presumably host-side only, like the other
/// `_can_implement` entry points in this file — confirm.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_binary_lerp_backward_bf16_can_implement(
    numel: i64,
    dy: *const c_void,
    da: *const c_void,
    db: *const c_void,
    p: f32,
) -> i32;
/// f64 launcher for `lerp` backward.
///
/// `p` is the scalar `weight` (kept as `f32` in the ABI). Returns a crate
/// status code (`0` = success).
///
/// # Safety
/// `dy`, `da`, and `db` must all reference `double` storage.
pub fn baracuda_kernels_binary_lerp_backward_f64_run(
    numel: i64,
    dy: *const c_void,
    da: *mut c_void,
    db: *mut c_void,
    p: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_binary_lerp_backward_f64_run`.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`da` / `db` are `*const` here). Returns a crate status code
/// (`0` = success; see the crate-level "Status codes" table).
///
/// NOTE(review): presumably host-side only, like the other
/// `_can_implement` entry points in this file — confirm.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_binary_lerp_backward_f64_can_implement(
    numel: i64,
    dy: *const c_void,
    da: *const c_void,
    db: *const c_void,
    p: f32,
) -> i32;
}
// ============================================================================
// Random / cuRAND — Phase 4.5
// ============================================================================
//
// Host-API cuRAND bindings + custom bespoke kernels for Bernoulli /
// Dropout. cuRAND covers `Uniform` and `Normal` directly. Bernoulli is
// custom (uniform + threshold → Bool), Dropout is custom (uniform +
// threshold + scale, writes both `y` and `mask`).
//
// Linkage: `cargo:rustc-link-lib=dylib=curand` (added in build.rs). The
// system resolves `libcurand.so` on Linux and `curand64_*.dll` on Windows
// from the CUDA installation that ships them alongside cudart.
/// Opaque cuRAND generator handle. Treated as a stateful object owned by
/// safe Rust at the plan layer — never inspect its internals here.
///
/// Obtained from [`curandCreateGenerator`] and released with
/// [`curandDestroyGenerator`].
// The lint is silenced so the alias can keep the C typedef's spelling.
#[allow(non_camel_case_types)]
pub type curandGenerator_t = *mut c_void;
/// `CURAND_RNG_PSEUDO_DEFAULT` — XORWOW pseudo-random generator. Adequate
/// for the dropout / sampling use cases this milestone targets; future
/// QRNG / Philox / MT19937 work can extend the descriptor surface.
///
/// Pass as the `rng_type` argument of [`curandCreateGenerator`].
pub const CURAND_RNG_PSEUDO_DEFAULT: i32 = 100;
/// `CURAND_STATUS_SUCCESS` — the only success code. Any non-zero return from
/// the cuRAND host API is mapped to status `5` ("internal kernel error")
/// at the safe-plan layer.
///
/// Compare the `i32` returned by every `curand*` binding against this value.
pub const CURAND_STATUS_SUCCESS: i32 = 0;
// Raw cuRAND host-API entry points. Every function returns the cuRAND
// status as an `i32`; `CURAND_STATUS_SUCCESS` (0) is the only success value.
unsafe extern "C" {
/// `curandCreateGenerator(generator, rng_type)`. Returns 0 on success.
///
/// `rng_type` selects the generator kind; this crate exposes
/// [`CURAND_RNG_PSEUDO_DEFAULT`] for it. On success the new handle is
/// written through `generator`.
///
/// # Safety
/// `generator` must point to writable storage for one `curandGenerator_t`.
pub fn curandCreateGenerator(
generator: *mut curandGenerator_t,
rng_type: i32,
) -> i32;
/// `curandSetPseudoRandomGeneratorSeed(generator, seed)`. Returns 0 on success.
///
/// `seed` maps to the C API's `unsigned long long` parameter.
///
/// # Safety
/// `generator` must be a valid handle returned by `curandCreateGenerator`.
pub fn curandSetPseudoRandomGeneratorSeed(
generator: curandGenerator_t,
seed: u64,
) -> i32;
/// `curandSetStream(generator, stream)`. Binds subsequent generator calls
/// to the given CUDA stream. Returns 0 on success.
///
/// `stream` is a `cudaStream_t` passed as `*mut c_void`, as elsewhere in
/// this crate.
///
/// # Safety
/// `generator` must be a valid handle; `stream` must be a valid CUDA stream
/// in the current context, or null for the default stream.
pub fn curandSetStream(generator: curandGenerator_t, stream: *mut c_void) -> i32;
/// `curandGenerateUniform(generator, ptr, n)` — writes `n` `float` samples
/// in `(0, 1]` to `ptr`. Returns 0 on success.
///
/// Note the interval: `0.0` is excluded and `1.0` is included.
///
/// # Safety
/// `generator` must be a valid handle. `ptr` must point to at least
/// `n * sizeof(f32)` writable device bytes.
pub fn curandGenerateUniform(
generator: curandGenerator_t,
ptr: *mut f32,
n: usize,
) -> i32;
/// `curandGenerateUniformDouble(generator, ptr, n)` — writes `n` `double`
/// samples in `(0, 1]` to `ptr`. Returns 0 on success.
///
/// # Safety
/// `generator` must be a valid handle. `ptr` must point to at least
/// `n * sizeof(f64)` writable device bytes.
pub fn curandGenerateUniformDouble(
generator: curandGenerator_t,
ptr: *mut f64,
n: usize,
) -> i32;
/// `curandGenerateNormal(generator, ptr, n, mean, stddev)` — writes `n`
/// normally-distributed `float` samples to `ptr`. Note: cuRAND
/// requires `n` be even for the Box-Muller pair generator. Returns 0
/// on success.
///
/// # Safety
/// `generator` must be a valid handle. `ptr` must point to at least
/// `n * sizeof(f32)` writable device bytes.
pub fn curandGenerateNormal(
generator: curandGenerator_t,
ptr: *mut f32,
n: usize,
mean: f32,
stddev: f32,
) -> i32;
/// `curandGenerateNormalDouble(generator, ptr, n, mean, stddev)`.
/// Same parity contract as `curandGenerateNormal` (`n` must be even).
/// Returns 0 on success.
///
/// # Safety
/// `generator` must be a valid handle. `ptr` must point to at least
/// `n * sizeof(f64)` writable device bytes.
pub fn curandGenerateNormalDouble(
generator: curandGenerator_t,
ptr: *mut f64,
n: usize,
mean: f64,
stddev: f64,
) -> i32;
/// `curandDestroyGenerator(generator)`. Returns 0 on success.
///
/// The handle must not be used again after this call.
///
/// # Safety
/// `generator` must be a valid handle returned by `curandCreateGenerator`
/// that has not been previously destroyed.
pub fn curandDestroyGenerator(generator: curandGenerator_t) -> i32;
}
// ----------------------------------------------------------------------------
// Bespoke random kernels — Bernoulli + Dropout
// ----------------------------------------------------------------------------
//
// Custom kernels are needed because cuRAND only generates uniform /
// normal directly. Bernoulli is a single dtype-independent kernel; Dropout
// has a forward and a backward kernel per dtype:
//
// * `bernoulli` — reads a `float` uniform-rand buffer and a
// probability `p`; writes Bool output (`1` if rand < p else `0`).
// * `dropout_<dtype>` — reads input `x` + `float` uniform-rand buffer +
// `p` (drop probability); writes `y = mask · x / (1 - p)` and `mask`
// (`1` kept, `0` dropped). Caller saves `mask` for backward.
// * `dropout_backward_<dtype>` — reads `dy` + saved `mask` + `p`; writes
// `dx = mask · dy / (1 - p)`.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// `bernoulli` over a `float` uniform-rand buffer.
///
/// Writes one `Bool` (encoded as `uint8_t` 0/1) per output cell:
/// `y[i] = (rand[i] < p) ? 1 : 0`.
///
/// Returns a crate status code (`0` = success; see the crate-level
/// "Status codes" table).
///
/// # Safety
/// `rand` points to `numel` `float` samples (caller-generated via
/// cuRAND); `y` points to `numel` `uint8_t` cells. `workspace` and
/// `stream` follow the crate-level contract.
pub fn baracuda_kernels_bernoulli_run(
numel: i64,
p: f32,
rand: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_bernoulli_run`.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// NOTE(review): presumably host-side only, like the other
/// `_can_implement` entry points in this file — confirm.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_bernoulli_can_implement(
numel: i64,
p: f32,
rand: *const c_void,
y: *const c_void,
) -> i32;
/// Dropout forward (f32). Writes:
/// - `y[i] = (rand[i] < (1 - p)) ? x[i] * scale : 0`
/// - `mask[i] = (rand[i] < (1 - p)) ? 1 : 0` (encoded as `uint8_t`)
/// where `scale = 1 / (1 - p)`. Caller computes `scale` to keep the
/// kernel branch-free of the `p == 1` edge case.
///
/// `p` is the drop probability. The caller keeps `mask` for the backward
/// pass. Returns a crate status code (`0` = success; see the crate-level
/// "Status codes" table).
///
/// # Safety
/// All tensor pointers reference device memory. `x` / `rand` / `y`
/// hold `f32`; `mask` is a packed Bool (`uint8_t`).
pub fn baracuda_kernels_dropout_f32_run(
numel: i64,
p: f32,
scale: f32,
x: *const c_void,
rand: *const c_void,
y: *mut c_void,
mask: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_dropout_f32_run`.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` / `mask` are `*const` here). Returns a crate status code
/// (`0` = success; see the crate-level "Status codes" table).
///
/// NOTE(review): presumably host-side only, like the other
/// `_can_implement` entry points in this file — confirm.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_dropout_f32_can_implement(
numel: i64,
p: f32,
scale: f32,
x: *const c_void,
rand: *const c_void,
y: *const c_void,
mask: *const c_void,
) -> i32;
/// Dropout forward (f64). Same shape as the f32 variant.
///
/// `p` stays `f32`, while `scale` is `f64` to match the storage dtype.
/// Returns a crate status code (`0` = success; see the crate-level
/// "Status codes" table).
///
/// # Safety
/// `x` / `y` reference `double`; `rand` reads `float` samples (one
/// per output cell); `mask` is a packed Bool (`uint8_t`).
pub fn baracuda_kernels_dropout_f64_run(
numel: i64,
p: f32,
scale: f64,
x: *const c_void,
rand: *const c_void,
y: *mut c_void,
mask: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_dropout_f64_run`.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` / `mask` are `*const` here). Returns a crate status code
/// (`0` = success; see the crate-level "Status codes" table).
///
/// NOTE(review): `scale` is declared `f32` here but `f64` in
/// `baracuda_kernels_dropout_f64_run`. If the C prototype takes `double`,
/// this declaration is an ABI mismatch — confirm against the C header
/// before relying on it.
///
/// NOTE(review): presumably host-side only, like the other
/// `_can_implement` entry points in this file — confirm.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_dropout_f64_can_implement(
numel: i64,
p: f32,
scale: f32,
x: *const c_void,
rand: *const c_void,
y: *const c_void,
mask: *const c_void,
) -> i32;
/// Dropout backward (f32). Writes `dx[i] = dy[i] * mask[i] * scale`
/// where `scale = 1 / (1 - p)`.
///
/// Only `scale` is passed; `p` itself is not an argument. `mask` is the
/// buffer written by the forward pass. Returns a crate status code
/// (`0` = success; see the crate-level "Status codes" table).
///
/// # Safety
/// `dy` / `dx` reference `float`; `mask` is a packed Bool (`uint8_t`).
pub fn baracuda_kernels_dropout_backward_f32_run(
numel: i64,
scale: f32,
dy: *const c_void,
mask: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_dropout_backward_f32_run`.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`dx` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// NOTE(review): presumably host-side only, like the other
/// `_can_implement` entry points in this file — confirm.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_dropout_backward_f32_can_implement(
numel: i64,
scale: f32,
dy: *const c_void,
mask: *const c_void,
dx: *const c_void,
) -> i32;
/// Dropout backward (f64).
///
/// `scale` is `f64` to match the storage dtype. Returns a crate status
/// code (`0` = success; see the crate-level "Status codes" table).
///
/// # Safety
/// `dy` / `dx` reference `double`; `mask` is a packed Bool (`uint8_t`).
pub fn baracuda_kernels_dropout_backward_f64_run(
numel: i64,
scale: f64,
dy: *const c_void,
mask: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_dropout_backward_f64_run`.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`dx` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// NOTE(review): `scale` is declared `f32` here but `f64` in
/// `baracuda_kernels_dropout_backward_f64_run`. If the C prototype takes
/// `double`, this declaration is an ABI mismatch — confirm against the C
/// header before relying on it.
///
/// NOTE(review): presumably host-side only, like the other
/// `_can_implement` entry points in this file — confirm.
///
/// # Safety
/// Same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_dropout_backward_f64_can_implement(
numel: i64,
scale: f32,
dy: *const c_void,
mask: *const c_void,
dx: *const c_void,
) -> i32;
/// In-place affine `y = scale * y + offset` (f32). Used by the
/// safe-plan layer to remap a cuRAND uniform-(0, 1] buffer into
/// `Uniform(low, high]`.
///
/// Returns a crate status code (`0` = success; see the crate-level
/// "Status codes" table).
///
/// # Safety
/// `y` points to `numel` `float` device cells. `workspace` and `stream`
/// follow the crate-level contract.
pub fn baracuda_kernels_affine_inplace_f32_run(
numel: i64,
scale: f32,
offset: f32,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_affine_inplace_f32`. Host-side only.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_affine_inplace_f32_can_implement(
numel: i64,
scale: f32,
offset: f32,
y: *const c_void,
) -> i32;
/// In-place affine `y = scale * y + offset` (f64).
///
/// `scale` and `offset` are `f64` here. Returns a crate status code
/// (`0` = success; see the crate-level "Status codes" table).
///
/// # Safety
/// `y` points to `numel` `double` device cells. `workspace` and `stream`
/// follow the crate-level contract.
pub fn baracuda_kernels_affine_inplace_f64_run(
numel: i64,
scale: f64,
offset: f64,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_affine_inplace_f64`. Host-side only.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_affine_inplace_f64_can_implement(
numel: i64,
scale: f64,
offset: f64,
y: *const c_void,
) -> i32;
/// In-place affine `y = scale * y + offset` (bf16). **Phase 61** —
/// added for Fuel's INPLACE_AFFINE op family completion (bf16/f16
/// weight-decay scaling, `Op::AddScalar` / `Op::MulScalar` on bf16
/// model weights).
///
/// `scale` and `offset` are always `f32` regardless of storage
/// dtype — matches the forward `affine_bf16_run` convention and
/// avoids passing `__nv_bfloat16` by value through the C ABI.
/// Internal compute happens at f32; storage at `__nv_bfloat16`.
///
/// Returns a crate status code (`0` = success; see the crate-level
/// "Status codes" table).
///
/// # Safety
/// `y` points to `numel * 2` bytes of device memory holding
/// `__nv_bfloat16` values. Same-pointer-only contract — there is no
/// `x` input.
pub fn baracuda_kernels_affine_inplace_bf16_run(
numel: i64,
scale: f32,
offset: f32,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_affine_inplace_bf16`. Host-side only.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_affine_inplace_bf16_can_implement(
numel: i64,
scale: f32,
offset: f32,
y: *const c_void,
) -> i32;
/// In-place affine `y = scale * y + offset` (f16). **Phase 61** —
/// added for Fuel's INPLACE_AFFINE op family completion.
///
/// `scale` and `offset` are always `f32` regardless of storage
/// dtype — matches the forward `affine_f16_run` convention. Internal
/// compute happens at f32; storage at `__half`.
///
/// Returns a crate status code (`0` = success; see the crate-level
/// "Status codes" table).
///
/// # Safety
/// `y` points to `numel * 2` bytes of device memory holding
/// `__half` values.
pub fn baracuda_kernels_affine_inplace_f16_run(
numel: i64,
scale: f32,
offset: f32,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_affine_inplace_f16`. Host-side only.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_affine_inplace_f16_can_implement(
numel: i64,
scale: f32,
offset: f32,
y: *const c_void,
) -> i32;
// ========================================================================
// Phase 62 — int dtype contig in-place backfill (i32 / i64 / u8 / i8).
// Matches the forward `affine_<int>_run` dtype set. Scalars are
// dtype-typed (i32 in-place takes i32 scale/offset; etc.) — matches
// the forward int affine convention. Integer overflow wraps per
// C++20 two's-complement modular semantics (same as forward int
// affine, see Phase 37 reduce-int integer-accumulator gotcha).
// ========================================================================
/// In-place affine `y = scale * y + offset` (i32). Phase 62.
///
/// `scale` and `offset` are dtype-typed (`i32`). Per the Phase 62 section
/// notes, integer overflow wraps. Returns a crate status code
/// (`0` = success; see the crate-level "Status codes" table).
///
/// # Safety
/// `y` points to `numel * 4` bytes holding `int32_t` values.
pub fn baracuda_kernels_affine_inplace_i32_run(
numel: i64,
scale: i32,
offset: i32,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_affine_inplace_i32`. Host-side only.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_affine_inplace_i32_can_implement(
numel: i64,
scale: i32,
offset: i32,
y: *const c_void,
) -> i32;
/// In-place affine `y = scale * y + offset` (i64). Phase 62.
///
/// `scale` and `offset` are dtype-typed (`i64`). Per the Phase 62 section
/// notes, integer overflow wraps. Returns a crate status code
/// (`0` = success; see the crate-level "Status codes" table).
///
/// # Safety
/// `y` points to `numel * 8` bytes holding `int64_t` values.
pub fn baracuda_kernels_affine_inplace_i64_run(
numel: i64,
scale: i64,
offset: i64,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_affine_inplace_i64`. Host-side only.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_affine_inplace_i64_can_implement(
numel: i64,
scale: i64,
offset: i64,
y: *const c_void,
) -> i32;
/// In-place affine `y = scale * y + offset` (u8). Phase 62.
///
/// `scale` and `offset` are dtype-typed (`u8`). Per the Phase 62 section
/// notes, integer overflow wraps. Returns a crate status code
/// (`0` = success; see the crate-level "Status codes" table).
///
/// # Safety
/// `y` points to `numel` bytes holding `uint8_t` values.
pub fn baracuda_kernels_affine_inplace_u8_run(
numel: i64,
scale: u8,
offset: u8,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_affine_inplace_u8`. Host-side only.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_affine_inplace_u8_can_implement(
numel: i64,
scale: u8,
offset: u8,
y: *const c_void,
) -> i32;
/// In-place affine `y = scale * y + offset` (i8). Phase 62.
///
/// `scale` and `offset` are dtype-typed (`i8`). Per the Phase 62 section
/// notes, integer overflow wraps. Returns a crate status code
/// (`0` = success; see the crate-level "Status codes" table).
///
/// # Safety
/// `y` points to `numel` bytes holding `int8_t` values.
pub fn baracuda_kernels_affine_inplace_i8_run(
numel: i64,
scale: i8,
offset: i8,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_affine_inplace_i8`. Host-side only.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_affine_inplace_i8_can_implement(
numel: i64,
scale: i8,
offset: i8,
y: *const c_void,
) -> i32;
// ========================================================================
// Phase 62 — strided in-place affine. ABI extends the contig
// version with `(rank, shape, stride_y)`; rank ≤ 8 enforced.
// Single stride array — caller is responsible for ensuring it's a
// valid permutation of element offsets (no zero strides on output;
// no two linear indices mapping to the same `y` cell). The kernel
// does no validation.
//
// If a caller is using this to replace a forward
// `affine_<dtype>_strided_run` with `x_ptr == y_ptr` semantics, the
// additional contract is `stride_x == stride_y` (use
// `baracuda_kernels_types::strides_equal` to check). With the
// contract honored, the kernel is structurally aliasing-safe (each
// thread reads its own `off_y` cell once, then writes back to the
// same cell — same per-thread access pattern as the contig
// affine_inplace family).
//
// Dtype set matches the forward `affine_<dtype>_strided_run` set:
// f32 / f64 / i32 / i64 / u8 / bf16 / f16 (no i8 strided forward,
// so no i8 strided in-place).
// ========================================================================
/// In-place affine `y[off] = scale * y[off] + offset` over a strided
/// view (f32). Phase 62.
///
/// Returns a crate status code (`0` = success; see the crate-level
/// "Status codes" table).
///
/// # Safety
/// `y` points to `numel * 4` bytes holding `float` values. `shape`
/// and `stride_y` are host-side arrays of length `rank` (≤ 8).
///
/// NOTE(review): with non-unit strides the touched span may exceed
/// `numel * 4` bytes — confirm the intended bound on `y`.
pub fn baracuda_kernels_affine_inplace_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
scale: f32,
offset: f32,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_affine_inplace_f32_strided`. Host-side only.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_affine_inplace_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
scale: f32,
offset: f32,
y: *const c_void,
) -> i32;
/// Strided in-place affine (f64). Phase 62.
///
/// `shape` and `stride_y` are read through host pointers alongside `rank`.
/// Returns a crate status code (`0` = success; see the crate-level
/// "Status codes" table).
///
/// # Safety
/// `y` points to `numel * 8` bytes holding `double` values.
///
/// NOTE(review): with non-unit strides the touched span may exceed
/// `numel * 8` bytes — confirm the intended bound on `y`.
pub fn baracuda_kernels_affine_inplace_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
scale: f64,
offset: f64,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_affine_inplace_f64_strided`. Host-side only.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_affine_inplace_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
scale: f64,
offset: f64,
y: *const c_void,
) -> i32;
/// Strided in-place affine (i32). Phase 62.
///
/// `scale` and `offset` are dtype-typed (`i32`). Returns a crate status
/// code (`0` = success; see the crate-level "Status codes" table).
///
/// # Safety
/// `y` points to `numel * 4` bytes holding `int32_t` values.
///
/// NOTE(review): with non-unit strides the touched span may exceed
/// `numel * 4` bytes — confirm the intended bound on `y`.
pub fn baracuda_kernels_affine_inplace_i32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
scale: i32,
offset: i32,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_affine_inplace_i32_strided`. Host-side only.
///
/// Takes the `_run` argument list minus the workspace / stream arguments
/// (`y` is `*const` here). Returns a crate status code (`0` = success;
/// see the crate-level "Status codes" table).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_affine_inplace_i32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
scale: i32,
offset: i32,
y: *const c_void,
) -> i32;
/// Strided in-place affine (i64). Phase 62.
///
/// `scale` and `offset` are dtype-typed (`i64`). Returns a crate status
/// code (`0` = success; see the crate-level "Status codes" table).
///
/// # Safety
/// `y` points to `numel * 8` bytes holding `int64_t` values.
///
/// NOTE(review): with non-unit strides the touched span may exceed
/// `numel * 8` bytes — confirm the intended bound on `y`.
pub fn baracuda_kernels_affine_inplace_i64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
scale: i64,
offset: i64,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_affine_inplace_i64_strided`. Host-side only.
///
/// Accepts the `_run` argument list without `workspace`,
/// `workspace_bytes` and `stream`, and takes `y` as `*const`. Returns a
/// crate status code (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_affine_inplace_i64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
scale: i64,
offset: i64,
y: *const c_void,
) -> i32;
/// Strided in-place affine (u8). Phase 62.
///
/// Updates the `uint8_t` buffer behind `y` in place using the `u8`
/// scalars `scale` and `offset` — presumably `y = scale * y + offset`
/// per element; confirm against the kernel source (wrap vs. saturate
/// behaviour is not stated here). `shape` and `stride_y` presumably
/// point to `rank` entries each — TODO confirm. Returns a crate status
/// code (`0` = success).
///
/// # Safety
/// `y` points to `numel` bytes holding `uint8_t` values.
/// NOTE(review): with a non-contiguous `stride_y` the touched span may
/// exceed `numel` bytes — confirm the exact bound. `workspace` (when
/// non-null) must cover `workspace_bytes` writable device bytes and
/// `stream` must be a valid CUDA stream, per the crate-level docs.
pub fn baracuda_kernels_affine_inplace_u8_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
scale: u8,
offset: u8,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_affine_inplace_u8_strided`. Host-side only.
///
/// Accepts the `_run` argument list without `workspace`,
/// `workspace_bytes` and `stream`, and takes `y` as `*const`. Returns a
/// crate status code (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_affine_inplace_u8_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
scale: u8,
offset: u8,
y: *const c_void,
) -> i32;
/// Strided in-place affine (bf16; f32 scalars). Phase 62.
///
/// Updates the `__nv_bfloat16` buffer behind `y` in place; `scale` and
/// `offset` cross the FFI as `f32` even though the operand is bf16.
/// Presumably `y = scale * y + offset` per element — confirm against
/// the kernel source. `shape` and `stride_y` presumably point to
/// `rank` entries each — TODO confirm. Returns a crate status code
/// (`0` = success).
///
/// # Safety
/// `y` points to `numel * 2` bytes holding `__nv_bfloat16` values.
/// NOTE(review): with a non-contiguous `stride_y` the touched span may
/// exceed `numel * 2` bytes — confirm the exact bound. `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes
/// and `stream` must be a valid CUDA stream, per the crate-level docs.
pub fn baracuda_kernels_affine_inplace_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
scale: f32,
offset: f32,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_affine_inplace_bf16_strided`. Host-side only.
///
/// Accepts the `_run` argument list without `workspace`,
/// `workspace_bytes` and `stream`, and takes `y` as `*const`. Returns a
/// crate status code (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_affine_inplace_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
scale: f32,
offset: f32,
y: *const c_void,
) -> i32;
/// Strided in-place affine (f16; f32 scalars). Phase 62.
///
/// Updates the `__half` buffer behind `y` in place; `scale` and
/// `offset` cross the FFI as `f32` even though the operand is f16.
/// Presumably `y = scale * y + offset` per element — confirm against
/// the kernel source. `shape` and `stride_y` presumably point to
/// `rank` entries each — TODO confirm. Returns a crate status code
/// (`0` = success).
///
/// # Safety
/// `y` points to `numel * 2` bytes holding `__half` values.
/// NOTE(review): with a non-contiguous `stride_y` the touched span may
/// exceed `numel * 2` bytes — confirm the exact bound. `workspace`
/// (when non-null) must cover `workspace_bytes` writable device bytes
/// and `stream` must be a valid CUDA stream, per the crate-level docs.
pub fn baracuda_kernels_affine_inplace_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
scale: f32,
offset: f32,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_affine_inplace_f16_strided`. Host-side only.
///
/// Accepts the `_run` argument list without `workspace`,
/// `workspace_bytes` and `stream`, and takes `y` as `*const`. Returns a
/// crate status code (`0` = success).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_affine_inplace_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
scale: f32,
offset: f32,
y: *const c_void,
) -> i32;
}
// ============================================================================
// Phase 6.1 — Attention positional encodings (Category K).
//
// RoPE (rotary position embedding) + ALiBi (attention-with-linear-biases),
// FW + BW × 4 FP dtypes. RoPE rotates pairs (2i, 2i+1) of a [B, H, S, D]
// Q/K tensor by per-position angles θ = pos · base^(-2i/D); BW reverses
// the trig sign (rotation by -θ). ALiBi adds slope[h]·(j-i) to score
// cell (b, h, i, j); BW is pass-through dA copy + per-head deterministic
// warp-shuffle reduction for dslope.
// ============================================================================
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// RoPE FW, f32. Input/output are [B, H, S, D] contiguous row-major;
/// `head_dim` (D) must be even. When `pos_default_flag != 0`, the
/// kernel ignores `positions` and uses position index = sequence
/// index; otherwise `positions` is `int64_t[seq]`.
///
/// Per the section header, pairs `(2i, 2i+1)` along D are rotated by
/// `θ = pos · base^(-2i/D)`. `x` is the input and `y` the output.
/// Returns a crate status code (`0` = success).
///
/// # Safety
/// `x` and `y` must be valid device pointers covering the whole
/// tensor; `positions` must be valid whenever `pos_default_flag == 0`.
/// `workspace` (when non-null) must cover `workspace_bytes` writable
/// device bytes and `stream` must be a valid CUDA stream, per the
/// crate-level docs.
pub fn baracuda_kernels_rope_f32_run(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
base: f32,
pos_default_flag: i32,
x: *const c_void,
positions: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_f32`. Host-side only.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with `y` as `*const`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_rope_f32_can_implement(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
base: f32,
pos_default_flag: i32,
x: *const c_void,
positions: *const c_void,
y: *const c_void,
) -> i32;
/// RoPE FW, f16 (f32 trig detour internally).
///
/// Shares the f32 variant's argument list: `[B, H, S, D]` operands (per
/// the section header) and an `f32` `base`. `positions` is presumably
/// consulted only when `pos_default_flag == 0`, as documented for f32 —
/// confirm. Returns a crate status code (`0` = success).
///
/// # Safety
/// `x`, `y`, and `positions` (when consulted) must be valid device
/// pointers; `workspace` / `stream` follow the crate-level contract.
pub fn baracuda_kernels_rope_f16_run(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
base: f32,
pos_default_flag: i32,
x: *const c_void,
positions: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_f16`. Host-side only.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with `y` as `*const`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_rope_f16_can_implement(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
base: f32,
pos_default_flag: i32,
x: *const c_void,
positions: *const c_void,
y: *const c_void,
) -> i32;
/// RoPE FW, bf16.
///
/// Shares the f32 variant's argument list: `[B, H, S, D]` operands (per
/// the section header) and an `f32` `base`. `positions` is presumably
/// consulted only when `pos_default_flag == 0`, as documented for f32 —
/// confirm. Returns a crate status code (`0` = success).
///
/// # Safety
/// `x`, `y`, and `positions` (when consulted) must be valid device
/// pointers; `workspace` / `stream` follow the crate-level contract.
pub fn baracuda_kernels_rope_bf16_run(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
base: f32,
pos_default_flag: i32,
x: *const c_void,
positions: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_bf16`. Host-side only.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with `y` as `*const`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_rope_bf16_can_implement(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
base: f32,
pos_default_flag: i32,
x: *const c_void,
positions: *const c_void,
y: *const c_void,
) -> i32;
/// RoPE FW, f64.
///
/// Shares the f32 variant's argument list: `[B, H, S, D]` operands (per
/// the section header). Note that `base` stays `f32` even for this f64
/// variant. `positions` is presumably consulted only when
/// `pos_default_flag == 0`, as documented for f32 — confirm. Returns a
/// crate status code (`0` = success).
///
/// # Safety
/// `x`, `y`, and `positions` (when consulted) must be valid device
/// pointers; `workspace` / `stream` follow the crate-level contract.
pub fn baracuda_kernels_rope_f64_run(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
base: f32,
pos_default_flag: i32,
x: *const c_void,
positions: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_f64`. Host-side only.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with `y` as `*const`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_rope_f64_can_implement(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
base: f32,
pos_default_flag: i32,
x: *const c_void,
positions: *const c_void,
y: *const c_void,
) -> i32;
/// RoPE BW, f32. Same shape as FW; computes `dx` from `dy` by
/// rotation through `-θ`.
///
/// `dy` is the incoming gradient (read-only) and `dx` the gradient
/// written out. `base`, `pos_default_flag` and `positions` presumably
/// carry the same meaning as in `rope_f32_run` — confirm. Returns a
/// crate status code (`0` = success).
///
/// # Safety
/// `dy`, `dx`, and `positions` (when consulted) must be valid device
/// pointers; `workspace` / `stream` follow the crate-level contract.
pub fn baracuda_kernels_rope_backward_f32_run(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
base: f32,
pos_default_flag: i32,
dy: *const c_void,
positions: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_backward_f32`. Host-side only.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with `dx` as `*const`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_rope_backward_f32_can_implement(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
base: f32,
pos_default_flag: i32,
dy: *const c_void,
positions: *const c_void,
dx: *const c_void,
) -> i32;
/// RoPE BW, f16.
///
/// Same argument list as `rope_backward_f32_run`: `dy` is read, `dx`
/// is written, and `base` is `f32`. Presumably the same rotation
/// through `-θ` as the f32 variant — confirm. Returns a crate status
/// code (`0` = success).
///
/// # Safety
/// `dy`, `dx`, and `positions` (when consulted) must be valid device
/// pointers; `workspace` / `stream` follow the crate-level contract.
pub fn baracuda_kernels_rope_backward_f16_run(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
base: f32,
pos_default_flag: i32,
dy: *const c_void,
positions: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_backward_f16`. Host-side only.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with `dx` as `*const`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_rope_backward_f16_can_implement(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
base: f32,
pos_default_flag: i32,
dy: *const c_void,
positions: *const c_void,
dx: *const c_void,
) -> i32;
/// RoPE BW, bf16.
///
/// Same argument list as `rope_backward_f32_run`: `dy` is read, `dx`
/// is written, and `base` is `f32`. Presumably the same rotation
/// through `-θ` as the f32 variant — confirm. Returns a crate status
/// code (`0` = success).
///
/// # Safety
/// `dy`, `dx`, and `positions` (when consulted) must be valid device
/// pointers; `workspace` / `stream` follow the crate-level contract.
pub fn baracuda_kernels_rope_backward_bf16_run(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
base: f32,
pos_default_flag: i32,
dy: *const c_void,
positions: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_backward_bf16`. Host-side only.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with `dx` as `*const`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_rope_backward_bf16_can_implement(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
base: f32,
pos_default_flag: i32,
dy: *const c_void,
positions: *const c_void,
dx: *const c_void,
) -> i32;
/// RoPE BW, f64.
///
/// Same argument list as `rope_backward_f32_run`: `dy` is read and
/// `dx` is written. Note that `base` stays `f32` even for this f64
/// variant. Presumably the same rotation through `-θ` as the f32
/// variant — confirm. Returns a crate status code (`0` = success).
///
/// # Safety
/// `dy`, `dx`, and `positions` (when consulted) must be valid device
/// pointers; `workspace` / `stream` follow the crate-level contract.
pub fn baracuda_kernels_rope_backward_f64_run(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
base: f32,
pos_default_flag: i32,
dy: *const c_void,
positions: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_backward_f64`. Host-side only.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with `dx` as `*const`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_rope_backward_f64_can_implement(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
base: f32,
pos_default_flag: i32,
dy: *const c_void,
positions: *const c_void,
dx: *const c_void,
) -> i32;
// ---- RoPE strided FW + BW × 4 dtypes — Phase 14.4 ----
//
// Outer dims (batch, heads, seq) carry signed-i64 element strides.
// The innermost `head_dim` axis is implicitly stride=1 (RoPE
// rotates adjacent pairs (2i, 2i+1) which must sit next to each
// other in memory). The Rust plan layer rejects any non-unit
// stride on head_dim before crossing the FFI.
/// RoPE FW strided, f32.
///
/// `stride_x_*` / `stride_y_*` are signed `i64` element strides for the
/// batch (`_b`), heads (`_h`) and seq (`_s`) axes of `x` and `y`; the
/// `head_dim` axis is implicitly unit-stride (Phase 14.4 section note).
/// Remaining arguments match `rope_f32_run`. Returns a crate status
/// code (`0` = success).
///
/// # Safety
/// Every element reachable through the given extents and strides must
/// lie in valid device memory — NOTE(review): the exact bound is not
/// stated here; confirm. `workspace` / `stream` follow the crate-level
/// contract.
pub fn baracuda_kernels_rope_f32_strided_run(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
stride_x_b: i64, stride_x_h: i64, stride_x_s: i64,
stride_y_b: i64, stride_y_h: i64, stride_y_s: i64,
base: f32,
pos_default_flag: i32,
x: *const c_void,
positions: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_f32_strided`. Host-side only.
///
/// Takes the `_run` arguments (including all six strides) minus
/// `workspace`, `workspace_bytes` and `stream`, with `y` as `*const`.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_f32_strided_can_implement(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
stride_x_b: i64, stride_x_h: i64, stride_x_s: i64,
stride_y_b: i64, stride_y_h: i64, stride_y_s: i64,
base: f32,
pos_default_flag: i32,
x: *const c_void,
positions: *const c_void,
y: *const c_void,
) -> i32;
/// RoPE FW strided, f16.
///
/// `stride_x_*` / `stride_y_*` are signed `i64` element strides for the
/// batch (`_b`), heads (`_h`) and seq (`_s`) axes of `x` and `y`; the
/// `head_dim` axis is implicitly unit-stride (Phase 14.4 section note).
/// Remaining arguments match `rope_f16_run`. Returns a crate status
/// code (`0` = success).
///
/// # Safety
/// Every element reachable through the given extents and strides must
/// lie in valid device memory — NOTE(review): the exact bound is not
/// stated here; confirm. `workspace` / `stream` follow the crate-level
/// contract.
pub fn baracuda_kernels_rope_f16_strided_run(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
stride_x_b: i64, stride_x_h: i64, stride_x_s: i64,
stride_y_b: i64, stride_y_h: i64, stride_y_s: i64,
base: f32,
pos_default_flag: i32,
x: *const c_void,
positions: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_f16_strided`. Host-side only.
///
/// Takes the `_run` arguments (including all six strides) minus
/// `workspace`, `workspace_bytes` and `stream`, with `y` as `*const`.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_f16_strided_can_implement(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
stride_x_b: i64, stride_x_h: i64, stride_x_s: i64,
stride_y_b: i64, stride_y_h: i64, stride_y_s: i64,
base: f32,
pos_default_flag: i32,
x: *const c_void,
positions: *const c_void,
y: *const c_void,
) -> i32;
/// RoPE FW strided, bf16.
///
/// `stride_x_*` / `stride_y_*` are signed `i64` element strides for the
/// batch (`_b`), heads (`_h`) and seq (`_s`) axes of `x` and `y`; the
/// `head_dim` axis is implicitly unit-stride (Phase 14.4 section note).
/// Remaining arguments match `rope_bf16_run`. Returns a crate status
/// code (`0` = success).
///
/// # Safety
/// Every element reachable through the given extents and strides must
/// lie in valid device memory — NOTE(review): the exact bound is not
/// stated here; confirm. `workspace` / `stream` follow the crate-level
/// contract.
pub fn baracuda_kernels_rope_bf16_strided_run(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
stride_x_b: i64, stride_x_h: i64, stride_x_s: i64,
stride_y_b: i64, stride_y_h: i64, stride_y_s: i64,
base: f32,
pos_default_flag: i32,
x: *const c_void,
positions: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_bf16_strided`. Host-side only.
///
/// Takes the `_run` arguments (including all six strides) minus
/// `workspace`, `workspace_bytes` and `stream`, with `y` as `*const`.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_bf16_strided_can_implement(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
stride_x_b: i64, stride_x_h: i64, stride_x_s: i64,
stride_y_b: i64, stride_y_h: i64, stride_y_s: i64,
base: f32,
pos_default_flag: i32,
x: *const c_void,
positions: *const c_void,
y: *const c_void,
) -> i32;
/// RoPE FW strided, f64.
///
/// `stride_x_*` / `stride_y_*` are signed `i64` element strides for the
/// batch (`_b`), heads (`_h`) and seq (`_s`) axes of `x` and `y`; the
/// `head_dim` axis is implicitly unit-stride (Phase 14.4 section note).
/// Remaining arguments match `rope_f64_run`; `base` stays `f32`.
/// Returns a crate status code (`0` = success).
///
/// # Safety
/// Every element reachable through the given extents and strides must
/// lie in valid device memory — NOTE(review): the exact bound is not
/// stated here; confirm. `workspace` / `stream` follow the crate-level
/// contract.
pub fn baracuda_kernels_rope_f64_strided_run(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
stride_x_b: i64, stride_x_h: i64, stride_x_s: i64,
stride_y_b: i64, stride_y_h: i64, stride_y_s: i64,
base: f32,
pos_default_flag: i32,
x: *const c_void,
positions: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_f64_strided`. Host-side only.
///
/// Takes the `_run` arguments (including all six strides) minus
/// `workspace`, `workspace_bytes` and `stream`, with `y` as `*const`.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_f64_strided_can_implement(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
stride_x_b: i64, stride_x_h: i64, stride_x_s: i64,
stride_y_b: i64, stride_y_h: i64, stride_y_s: i64,
base: f32,
pos_default_flag: i32,
x: *const c_void,
positions: *const c_void,
y: *const c_void,
) -> i32;
/// RoPE BW strided, f32. Strides apply to `dy` (input) and `dx` (output).
///
/// `stride_dy_*` / `stride_dx_*` are signed `i64` element strides for
/// the batch (`_b`), heads (`_h`) and seq (`_s`) axes; the `head_dim`
/// axis is implicitly unit-stride (Phase 14.4 section note). Remaining
/// arguments match `rope_backward_f32_run`. Returns a crate status
/// code (`0` = success).
///
/// # Safety
/// Every element reachable through the given extents and strides must
/// lie in valid device memory — NOTE(review): the exact bound is not
/// stated here; confirm. `workspace` / `stream` follow the crate-level
/// contract.
pub fn baracuda_kernels_rope_backward_f32_strided_run(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
stride_dy_b: i64, stride_dy_h: i64, stride_dy_s: i64,
stride_dx_b: i64, stride_dx_h: i64, stride_dx_s: i64,
base: f32,
pos_default_flag: i32,
dy: *const c_void,
positions: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_backward_f32_strided`. Host-side only.
///
/// Takes the `_run` arguments (including all six strides) minus
/// `workspace`, `workspace_bytes` and `stream`, with `dx` as `*const`.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_backward_f32_strided_can_implement(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
stride_dy_b: i64, stride_dy_h: i64, stride_dy_s: i64,
stride_dx_b: i64, stride_dx_h: i64, stride_dx_s: i64,
base: f32,
pos_default_flag: i32,
dy: *const c_void,
positions: *const c_void,
dx: *const c_void,
) -> i32;
/// RoPE BW strided, f16.
///
/// `stride_dy_*` / `stride_dx_*` are signed `i64` element strides for
/// the batch (`_b`), heads (`_h`) and seq (`_s`) axes of `dy` (read)
/// and `dx` (written); the `head_dim` axis is implicitly unit-stride
/// (Phase 14.4 section note). Remaining arguments match
/// `rope_backward_f16_run`. Returns a crate status code (`0` = success).
///
/// # Safety
/// Every element reachable through the given extents and strides must
/// lie in valid device memory — NOTE(review): the exact bound is not
/// stated here; confirm. `workspace` / `stream` follow the crate-level
/// contract.
pub fn baracuda_kernels_rope_backward_f16_strided_run(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
stride_dy_b: i64, stride_dy_h: i64, stride_dy_s: i64,
stride_dx_b: i64, stride_dx_h: i64, stride_dx_s: i64,
base: f32,
pos_default_flag: i32,
dy: *const c_void,
positions: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_backward_f16_strided`. Host-side only.
///
/// Takes the `_run` arguments (including all six strides) minus
/// `workspace`, `workspace_bytes` and `stream`, with `dx` as `*const`.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_backward_f16_strided_can_implement(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
stride_dy_b: i64, stride_dy_h: i64, stride_dy_s: i64,
stride_dx_b: i64, stride_dx_h: i64, stride_dx_s: i64,
base: f32,
pos_default_flag: i32,
dy: *const c_void,
positions: *const c_void,
dx: *const c_void,
) -> i32;
/// RoPE BW strided, bf16.
///
/// `stride_dy_*` / `stride_dx_*` are signed `i64` element strides for
/// the batch (`_b`), heads (`_h`) and seq (`_s`) axes of `dy` (read)
/// and `dx` (written); the `head_dim` axis is implicitly unit-stride
/// (Phase 14.4 section note). Remaining arguments match
/// `rope_backward_bf16_run`. Returns a crate status code (`0` = success).
///
/// # Safety
/// Every element reachable through the given extents and strides must
/// lie in valid device memory — NOTE(review): the exact bound is not
/// stated here; confirm. `workspace` / `stream` follow the crate-level
/// contract.
pub fn baracuda_kernels_rope_backward_bf16_strided_run(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
stride_dy_b: i64, stride_dy_h: i64, stride_dy_s: i64,
stride_dx_b: i64, stride_dx_h: i64, stride_dx_s: i64,
base: f32,
pos_default_flag: i32,
dy: *const c_void,
positions: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_backward_bf16_strided`. Host-side only.
///
/// Takes the `_run` arguments (including all six strides) minus
/// `workspace`, `workspace_bytes` and `stream`, with `dx` as `*const`.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_backward_bf16_strided_can_implement(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
stride_dy_b: i64, stride_dy_h: i64, stride_dy_s: i64,
stride_dx_b: i64, stride_dx_h: i64, stride_dx_s: i64,
base: f32,
pos_default_flag: i32,
dy: *const c_void,
positions: *const c_void,
dx: *const c_void,
) -> i32;
/// RoPE BW strided, f64.
///
/// `stride_dy_*` / `stride_dx_*` are signed `i64` element strides for
/// the batch (`_b`), heads (`_h`) and seq (`_s`) axes of `dy` (read)
/// and `dx` (written); the `head_dim` axis is implicitly unit-stride
/// (Phase 14.4 section note). Remaining arguments match
/// `rope_backward_f64_run`; `base` stays `f32`. Returns a crate status
/// code (`0` = success).
///
/// # Safety
/// Every element reachable through the given extents and strides must
/// lie in valid device memory — NOTE(review): the exact bound is not
/// stated here; confirm. `workspace` / `stream` follow the crate-level
/// contract.
pub fn baracuda_kernels_rope_backward_f64_strided_run(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
stride_dy_b: i64, stride_dy_h: i64, stride_dy_s: i64,
stride_dx_b: i64, stride_dx_h: i64, stride_dx_s: i64,
base: f32,
pos_default_flag: i32,
dy: *const c_void,
positions: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_backward_f64_strided`. Host-side only.
///
/// Takes the `_run` arguments (including all six strides) minus
/// `workspace`, `workspace_bytes` and `stream`, with `dx` as `*const`.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_backward_f64_strided_can_implement(
batch: i32,
heads: i32,
seq: i32,
head_dim: i32,
stride_dy_b: i64, stride_dy_h: i64, stride_dy_s: i64,
stride_dx_b: i64, stride_dx_h: i64, stride_dx_s: i64,
base: f32,
pos_default_flag: i32,
dy: *const c_void,
positions: *const c_void,
dx: *const c_void,
) -> i32;
// ---- Phase 36 (Fuel ask Gap 2) — RoPE apply with caller-supplied
// cos/sin tables ----
//
// Coexists with the existing `rope_<dt>_run` family (which derives
// θ internally from `pos · base^(-2i/D)`). The apply variant is the
// LLaMA-style extended-context API — YaRN, NTK, and dynamic-scaling
// schedules pre-bake the trig values; this kernel just consumes
// them.
//
// Flat layout — `x` / `y` are `[bh, td]` with `bh = batch * heads`
// and `td = seq * head_dim` elements per (batch, head) row. `cos` / `sin` are
// always f32 over the FFI (regardless of operand dtype); f16/bf16
// detour through f32 internally, f64 promotes the f32 tables to
// double at load. `stride_b = 0` means the cos/sin table is shared
// across all `bh` rows; `stride_b = td/2` means one cos/sin table
// per `bh` row.
/// RoPE apply FW, f32. Cos/sin tables provided by caller.
///
/// `x` / `y` use the flat `[bh, td]` layout; `cos_tab` / `sin_tab` are
/// `f32` tables, with `stride_b = 0` for one table shared by all `bh`
/// rows and `stride_b = td/2` for one table per row (Phase 36 section
/// note). `d` is presumably the head dimension — TODO confirm.
/// Returns a crate status code (`0` = success).
///
/// **Aliasing (Phase 64) — NOT in-place safe**: aliasing `y` with
/// `x` is UNSAFE. RoPE rotates pairs `(x[even], x[odd])` via a
/// 2×2 rotation matrix, with two threads per pair both reading
/// both pair elements. If the even thread runs first and writes
/// `y[even]`, the odd thread then reads `y[even]` (= rotated
/// value) instead of the original `x[even]`, producing the wrong
/// rotation. Callers that need an in-place RoPE must do it via
/// an explicit double-buffer (or accept a fresh output buffer).
/// Aliasing `y` with `cos_tab` / `sin_tab` is also unsafe (those
/// tables are read across multiple threads). This contract is
/// stable across baracuda versions and applies to every
/// `rope_apply_<dt>_run`, `rope_apply_interleaved_<dt>_run`,
/// and `rope_apply_thd_<dt>_run` variant.
pub fn baracuda_kernels_rope_apply_f32_run(
bh: i32,
td: i32,
d: i32,
stride_b: i32,
x: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_f32`. Host-side only.
///
/// Takes only the four integer arguments of the matching `_run`; no
/// pointers cross the FFI. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_f32_can_implement(
bh: i32, td: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply FW, f16 (f32 trig table, f32 multiply detour).
///
/// `x` / `y` use the flat `[bh, td]` layout; `cos_tab` / `sin_tab` are
/// `f32` tables, with `stride_b = 0` for one shared table and
/// `stride_b = td/2` for one table per `bh` row (Phase 36 section
/// note). `d` is presumably the head dimension — TODO confirm.
/// Not in-place safe: see the aliasing contract on
/// `baracuda_kernels_rope_apply_f32_run`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_rope_apply_f16_run(
bh: i32,
td: i32,
d: i32,
stride_b: i32,
x: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_f16`.
///
/// Takes only the four integer arguments of the matching `_run`; no
/// pointers cross the FFI. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_f16_can_implement(
bh: i32, td: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply FW, bf16 (f32 trig table, f32 multiply detour).
///
/// `x` / `y` use the flat `[bh, td]` layout; `cos_tab` / `sin_tab` are
/// `f32` tables, with `stride_b = 0` for one shared table and
/// `stride_b = td/2` for one table per `bh` row (Phase 36 section
/// note). `d` is presumably the head dimension — TODO confirm.
/// Not in-place safe: see the aliasing contract on
/// `baracuda_kernels_rope_apply_f32_run`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_rope_apply_bf16_run(
bh: i32,
td: i32,
d: i32,
stride_b: i32,
x: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_bf16`.
///
/// Takes only the four integer arguments of the matching `_run`; no
/// pointers cross the FFI. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_bf16_can_implement(
bh: i32, td: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply FW, f64 (f32 trig table promoted to double at load).
///
/// `x` / `y` use the flat `[bh, td]` layout; `cos_tab` / `sin_tab` stay
/// `f32` tables even for this f64 variant, with `stride_b = 0` for one
/// shared table and `stride_b = td/2` for one table per `bh` row
/// (Phase 36 section note). `d` is presumably the head dimension —
/// TODO confirm. Not in-place safe: see the aliasing contract on
/// `baracuda_kernels_rope_apply_f32_run`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_rope_apply_f64_run(
bh: i32,
td: i32,
d: i32,
stride_b: i32,
x: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_f64`.
///
/// Takes only the four integer arguments of the matching `_run`; no
/// pointers cross the FFI. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_f64_can_implement(
bh: i32, td: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply BW, f32. Same cos/sin tables as FW; orthogonal-rotation reverse.
///
/// Reads the gradient `dy` and writes `dx`; both presumably use the
/// FW's flat `[bh, td]` layout — confirm. `stride_b` selects shared
/// (`0`) or per-row (`td/2`) `f32` tables per the Phase 36 section
/// note. NOTE(review): whether the FW aliasing contract also covers
/// `dx` / `dy` is not stated here — confirm before aliasing them.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_backward_f32_run(
bh: i32,
td: i32,
d: i32,
stride_b: i32,
dy: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_backward_f32`.
///
/// Takes only the four integer arguments of the matching `_run`; no
/// pointers cross the FFI. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_backward_f32_can_implement(
bh: i32, td: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply BW, f16.
///
/// Reads the gradient `dy` and writes `dx`; both presumably use the
/// FW's flat `[bh, td]` layout — confirm. `cos_tab` / `sin_tab` are
/// `f32` tables; `stride_b` selects shared (`0`) or per-row (`td/2`)
/// tables per the Phase 36 section note. NOTE(review): whether the FW
/// aliasing contract also covers `dx` / `dy` is not stated here —
/// confirm. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_backward_f16_run(
bh: i32,
td: i32,
d: i32,
stride_b: i32,
dy: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_backward_f16`.
///
/// Takes only the four integer arguments of the matching `_run`; no
/// pointers cross the FFI. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_backward_f16_can_implement(
bh: i32, td: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply BW, bf16.
///
/// Reads the gradient `dy` and writes `dx`; both presumably use the
/// FW's flat `[bh, td]` layout — confirm. `cos_tab` / `sin_tab` are
/// `f32` tables; `stride_b` selects shared (`0`) or per-row (`td/2`)
/// tables per the Phase 36 section note. NOTE(review): whether the FW
/// aliasing contract also covers `dx` / `dy` is not stated here —
/// confirm. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_backward_bf16_run(
bh: i32,
td: i32,
d: i32,
stride_b: i32,
dy: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_backward_bf16`.
///
/// Takes only the four integer arguments of the matching `_run`; no
/// pointers cross the FFI. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_backward_bf16_can_implement(
bh: i32, td: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply BW, f64.
///
/// Reads the gradient `dy` and writes `dx`; both presumably use the
/// FW's flat `[bh, td]` layout — confirm. `cos_tab` / `sin_tab` stay
/// `f32` tables over the FFI; `stride_b` selects shared (`0`) or
/// per-row (`td/2`) tables per the Phase 36 section note.
/// NOTE(review): whether the FW aliasing contract also covers `dx` /
/// `dy` is not stated here — confirm. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_rope_apply_backward_f64_run(
bh: i32,
td: i32,
d: i32,
stride_b: i32,
dy: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_backward_f64`.
///
/// Takes only the four integer arguments of the matching `_run`; no
/// pointers cross the FFI. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_backward_f64_can_implement(
bh: i32, td: i32, d: i32, stride_b: i32,
) -> i32;
// ----------------------------------------------------------------
// Phase 41 (Fuel Phase 6c.4 Gap 7) — RoPE apply INTERLEAVED variant.
// ----------------------------------------------------------------
//
// Same FFI shape and semantics as the Phase 36 `rope_apply_<dt>`
// family — caller-supplied f32 cos/sin tables, `[bh, td]` flat
// operand layout. The interleaved name pins the pair convention
// `(2k, 2k+1)` with cos/sin indexed by `pair = dim_idx >> 1`
// (the same pairing already implemented by `rope_apply_<dt>_run`).
// Exposed as a separate symbol so callers using the `RotaryEmbI`
// API can drop the PTX module and link directly against baracuda.
/// RoPE apply interleaved FW, f32.
///
/// Same FFI shape as `rope_apply_f32_run`: flat `[bh, td]` operands and
/// caller-supplied `f32` `cos_tab` / `sin_tab`. Pairs are `(2k, 2k+1)`
/// with the tables indexed by `pair = dim_idx >> 1` (Phase 41 section
/// note). Not in-place safe: see the aliasing contract on
/// `baracuda_kernels_rope_apply_f32_run`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_rope_apply_interleaved_f32_run(
bh: i32,
td: i32,
d: i32,
stride_b: i32,
x: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_interleaved_f32`.
///
/// Takes only the four integer arguments of the matching `_run`; no
/// pointers cross the FFI. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_interleaved_f32_can_implement(
bh: i32, td: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply interleaved FW, f16.
///
/// Same FFI shape as `rope_apply_f16_run`: flat `[bh, td]` operands and
/// caller-supplied `f32` `cos_tab` / `sin_tab`. Pairs are `(2k, 2k+1)`
/// with the tables indexed by `pair = dim_idx >> 1` (Phase 41 section
/// note). Not in-place safe: see the aliasing contract on
/// `baracuda_kernels_rope_apply_f32_run`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_rope_apply_interleaved_f16_run(
bh: i32,
td: i32,
d: i32,
stride_b: i32,
x: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_interleaved_f16`.
///
/// Takes only the four integer arguments of the matching `_run`; no
/// pointers cross the FFI. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_interleaved_f16_can_implement(
bh: i32, td: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply interleaved FW, bf16.
///
/// Same FFI shape as `rope_apply_bf16_run`: flat `[bh, td]` operands
/// and caller-supplied `f32` `cos_tab` / `sin_tab`. Pairs are
/// `(2k, 2k+1)` with the tables indexed by `pair = dim_idx >> 1`
/// (Phase 41 section note). Not in-place safe: see the aliasing
/// contract on `baracuda_kernels_rope_apply_f32_run`. Returns a crate
/// status code (`0` = success).
pub fn baracuda_kernels_rope_apply_interleaved_bf16_run(
bh: i32,
td: i32,
d: i32,
stride_b: i32,
x: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_interleaved_bf16`.
///
/// Takes only the four integer arguments of the matching `_run`; no
/// pointers cross the FFI. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_interleaved_bf16_can_implement(
bh: i32, td: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply interleaved FW, f64.
///
/// Same FFI shape as `rope_apply_f64_run`: flat `[bh, td]` operands and
/// caller-supplied `f32` `cos_tab` / `sin_tab`. Pairs are `(2k, 2k+1)`
/// with the tables indexed by `pair = dim_idx >> 1` (Phase 41 section
/// note). Not in-place safe: see the aliasing contract on
/// `baracuda_kernels_rope_apply_f32_run`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_rope_apply_interleaved_f64_run(
bh: i32,
td: i32,
d: i32,
stride_b: i32,
x: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_interleaved_f64`.
///
/// Takes only the four integer arguments of the matching `_run`; no
/// pointers cross the FFI. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_interleaved_f64_can_implement(
bh: i32, td: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply interleaved BW, f32.
///
/// Reads the gradient `dy` and writes `dx`, taking the same integer
/// arguments and `f32` `cos_tab` / `sin_tab` as the interleaved FW.
/// Presumably the reverse rotation of the FW, as for
/// `rope_apply_backward_f32_run` — confirm. NOTE(review): aliasing
/// rules for `dx` / `dy` are not stated here — confirm. Returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_interleaved_backward_f32_run(
bh: i32,
td: i32,
d: i32,
stride_b: i32,
dy: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_interleaved_backward_f32`.
///
/// Takes only the four integer arguments of the matching `_run`; no
/// pointers cross the FFI. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_interleaved_backward_f32_can_implement(
bh: i32, td: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply interleaved BW, f16.
///
/// Reads the gradient `dy` and writes `dx`, taking the same integer
/// arguments and `f32` `cos_tab` / `sin_tab` as the interleaved FW.
/// Presumably the reverse rotation of the FW, as for
/// `rope_apply_backward_f32_run` — confirm. NOTE(review): aliasing
/// rules for `dx` / `dy` are not stated here — confirm. Returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_interleaved_backward_f16_run(
bh: i32,
td: i32,
d: i32,
stride_b: i32,
dy: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_interleaved_backward_f16`.
///
/// Takes only the four integer arguments of the matching `_run`; no
/// pointers cross the FFI. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_interleaved_backward_f16_can_implement(
bh: i32, td: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply interleaved BW, bf16.
///
/// Reads the gradient `dy` and writes `dx`, taking the same integer
/// arguments and `f32` `cos_tab` / `sin_tab` as the interleaved FW.
/// Presumably the reverse rotation of the FW, as for
/// `rope_apply_backward_f32_run` — confirm. NOTE(review): aliasing
/// rules for `dx` / `dy` are not stated here — confirm. Returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_interleaved_backward_bf16_run(
bh: i32,
td: i32,
d: i32,
stride_b: i32,
dy: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_interleaved_backward_bf16`.
///
/// Takes only the four integer arguments of the matching `_run`; no
/// pointers cross the FFI. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_interleaved_backward_bf16_can_implement(
bh: i32, td: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply interleaved BW, f64.
///
/// Reads the gradient `dy` and writes `dx`, taking the same integer
/// arguments and `f32` `cos_tab` / `sin_tab` as the interleaved FW.
/// Presumably the reverse rotation of the FW, as for
/// `rope_apply_backward_f32_run` — confirm. NOTE(review): aliasing
/// rules for `dx` / `dy` are not stated here — confirm. Returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_interleaved_backward_f64_run(
bh: i32,
td: i32,
d: i32,
stride_b: i32,
dy: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_interleaved_backward_f64`.
///
/// Takes only the four integer arguments of the matching `_run`; no
/// pointers cross the FFI. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_interleaved_backward_f64_can_implement(
bh: i32, td: i32, d: i32, stride_b: i32,
) -> i32;
// ----------------------------------------------------------------
// Phase 41 (Fuel Phase 6c.4 Gap 8) — RoPE apply THD-layout variant.
// ----------------------------------------------------------------
//
// Operand layout `[T, H, D]` (T packs batch * seq) instead of the
// canonical `[B, H, T, D]`. Per-cell addressing
// `x[t * (H * D) + h * D + dim]`. cos/sin tables remain f32 over
// the FFI; layout is `cs[t * stride_b + pair]` with `stride_b ==
// D/2` per-t tables or `stride_b == 0` for a single shared `[D/2]`
// table. Pair convention `(2k, 2k+1)` matches the canonical apply.
/// RoPE apply THD FW, f32.
///
/// Operands use the `[T, H, D]` layout, addressed as
/// `x[t * (H * D) + h * D + dim]`; `cos_tab` / `sin_tab` are `f32`
/// tables read as `cs[t * stride_b + pair]`, with `stride_b == D/2`
/// for per-`t` tables or `stride_b == 0` for one shared `[D/2]` table
/// (Phase 41 section note). `t_outer` / `h_heads` are presumably the T
/// and H extents — TODO confirm. Not in-place safe: see the aliasing
/// contract on `baracuda_kernels_rope_apply_f32_run`. Returns a crate
/// status code (`0` = success).
pub fn baracuda_kernels_rope_apply_thd_f32_run(
t_outer: i32,
h_heads: i32,
d: i32,
stride_b: i32,
x: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_thd_f32`.
///
/// Takes only `t_outer`, `h_heads`, `d` and `stride_b` from the
/// matching `_run`; no pointers cross the FFI. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_rope_apply_thd_f32_can_implement(
t_outer: i32, h_heads: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply THD FW, f16.
///
/// Operands use the `[T, H, D]` layout, addressed as
/// `x[t * (H * D) + h * D + dim]`; `cos_tab` / `sin_tab` are `f32`
/// tables read as `cs[t * stride_b + pair]`, with `stride_b == D/2`
/// for per-`t` tables or `stride_b == 0` for one shared `[D/2]` table
/// (Phase 41 section note). `t_outer` / `h_heads` are presumably the T
/// and H extents — TODO confirm. Not in-place safe: see the aliasing
/// contract on `baracuda_kernels_rope_apply_f32_run`. Returns a crate
/// status code (`0` = success).
pub fn baracuda_kernels_rope_apply_thd_f16_run(
t_outer: i32,
h_heads: i32,
d: i32,
stride_b: i32,
x: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_thd_f16`.
///
/// Takes only `t_outer`, `h_heads`, `d` and `stride_b` from the
/// matching `_run`; no pointers cross the FFI. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_rope_apply_thd_f16_can_implement(
t_outer: i32, h_heads: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply THD FW, bf16.
///
/// Operands use the `[T, H, D]` layout, addressed as
/// `x[t * (H * D) + h * D + dim]`; `cos_tab` / `sin_tab` are `f32`
/// tables read as `cs[t * stride_b + pair]`, with `stride_b == D/2`
/// for per-`t` tables or `stride_b == 0` for one shared `[D/2]` table
/// (Phase 41 section note). `t_outer` / `h_heads` are presumably the T
/// and H extents — TODO confirm. Not in-place safe: see the aliasing
/// contract on `baracuda_kernels_rope_apply_f32_run`. Returns a crate
/// status code (`0` = success).
pub fn baracuda_kernels_rope_apply_thd_bf16_run(
t_outer: i32,
h_heads: i32,
d: i32,
stride_b: i32,
x: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_thd_bf16`.
///
/// Takes only `t_outer`, `h_heads`, `d` and `stride_b` from the
/// matching `_run`; no pointers cross the FFI. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_rope_apply_thd_bf16_can_implement(
t_outer: i32, h_heads: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply THD FW, f64.
///
/// Operands use the `[T, H, D]` layout, addressed as
/// `x[t * (H * D) + h * D + dim]`; `cos_tab` / `sin_tab` remain `f32`
/// tables read as `cs[t * stride_b + pair]`, with `stride_b == D/2`
/// for per-`t` tables or `stride_b == 0` for one shared `[D/2]` table
/// (Phase 41 section note). `t_outer` / `h_heads` are presumably the T
/// and H extents — TODO confirm. Not in-place safe: see the aliasing
/// contract on `baracuda_kernels_rope_apply_f32_run`. Returns a crate
/// status code (`0` = success).
pub fn baracuda_kernels_rope_apply_thd_f64_run(
t_outer: i32,
h_heads: i32,
d: i32,
stride_b: i32,
x: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_thd_f64`.
///
/// Takes only `t_outer`, `h_heads`, `d` and `stride_b` from the
/// matching `_run`; no pointers cross the FFI. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_rope_apply_thd_f64_can_implement(
t_outer: i32, h_heads: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply THD BW, f32.
///
/// Reads the gradient `dy` and writes `dx`, taking the same integer
/// arguments and `f32` `cos_tab` / `sin_tab` as the THD FW; operands
/// presumably share its `[T, H, D]` layout — confirm. NOTE(review):
/// aliasing rules for `dx` / `dy` are not stated here — confirm.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_thd_backward_f32_run(
t_outer: i32,
h_heads: i32,
d: i32,
stride_b: i32,
dy: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_thd_backward_f32`.
///
/// Takes only `t_outer`, `h_heads`, `d` and `stride_b` from the
/// matching `_run`; no pointers cross the FFI. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_rope_apply_thd_backward_f32_can_implement(
t_outer: i32, h_heads: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply THD BW, f16.
///
/// Reads the gradient `dy` and writes `dx`, taking the same integer
/// arguments and `f32` `cos_tab` / `sin_tab` as the THD FW; operands
/// presumably share its `[T, H, D]` layout — confirm. NOTE(review):
/// aliasing rules for `dx` / `dy` are not stated here — confirm.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_thd_backward_f16_run(
t_outer: i32,
h_heads: i32,
d: i32,
stride_b: i32,
dy: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_thd_backward_f16`.
///
/// Takes only `t_outer`, `h_heads`, `d` and `stride_b` from the
/// matching `_run`; no pointers cross the FFI. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_rope_apply_thd_backward_f16_can_implement(
t_outer: i32, h_heads: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply THD BW, bf16.
///
/// Reads the gradient `dy` and writes `dx`, taking the same integer
/// arguments and `f32` `cos_tab` / `sin_tab` as the THD FW; operands
/// presumably share its `[T, H, D]` layout — confirm. NOTE(review):
/// aliasing rules for `dx` / `dy` are not stated here — confirm.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_thd_backward_bf16_run(
t_outer: i32,
h_heads: i32,
d: i32,
stride_b: i32,
dy: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_thd_backward_bf16`.
///
/// Takes only `t_outer`, `h_heads`, `d` and `stride_b` from the
/// matching `_run`; no pointers cross the FFI. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_rope_apply_thd_backward_bf16_can_implement(
t_outer: i32, h_heads: i32, d: i32, stride_b: i32,
) -> i32;
/// RoPE apply THD BW, f64.
///
/// Reads the gradient `dy` and writes `dx`, taking the same integer
/// arguments and `f32` `cos_tab` / `sin_tab` as the THD FW; operands
/// presumably share its `[T, H, D]` layout — confirm. NOTE(review):
/// aliasing rules for `dx` / `dy` are not stated here — confirm.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_rope_apply_thd_backward_f64_run(
t_outer: i32,
h_heads: i32,
d: i32,
stride_b: i32,
dy: *const c_void,
cos_tab: *const c_void,
sin_tab: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `rope_apply_thd_backward_f64`.
///
/// Takes only `t_outer`, `h_heads`, `d` and `stride_b` from the
/// matching `_run`; no pointers cross the FFI. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_rope_apply_thd_backward_f64_can_implement(
t_outer: i32, h_heads: i32, d: i32, stride_b: i32,
) -> i32;
/// ALiBi FW, f32. `y[b, h, i, j] = scores[b, h, i, j] + slopes[h] · (j - i)`.
///
/// `scores` and `y` are indexed by `(b, h, i, j)` and `slopes` by `h`;
/// presumably `i` ranges over `q_len` and `j` over `k_len` — TODO
/// confirm. Returns a crate status code (`0` = success).
///
/// # Safety
/// `scores`, `slopes` and `y` must be valid device pointers;
/// `workspace` (when non-null) must cover `workspace_bytes` writable
/// device bytes and `stream` must be a valid CUDA stream, per the
/// crate-level docs.
pub fn baracuda_kernels_alibi_f32_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
scores: *const c_void,
slopes: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `alibi_f32`. Host-side only.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with `y` as `*const`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_alibi_f32_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
scores: *const c_void,
slopes: *const c_void,
y: *const c_void,
) -> i32;
/// ALiBi FW, f16.
///
/// Same argument list as `alibi_f32_run`, which documents
/// `y[b, h, i, j] = scores[b, h, i, j] + slopes[h] · (j - i)`.
/// NOTE(review): the element type of `slopes` for this dtype is not
/// stated here — confirm. Returns a crate status code (`0` = success).
///
/// # Safety
/// `scores`, `slopes` and `y` must be valid device pointers;
/// `workspace` / `stream` follow the crate-level contract.
pub fn baracuda_kernels_alibi_f16_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
scores: *const c_void,
slopes: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `alibi_f16`. Host-side only.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with `y` as `*const`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_alibi_f16_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
scores: *const c_void,
slopes: *const c_void,
y: *const c_void,
) -> i32;
/// ALiBi FW, bf16.
///
/// Same argument list as `alibi_f32_run`, which documents
/// `y[b, h, i, j] = scores[b, h, i, j] + slopes[h] · (j - i)`.
/// NOTE(review): the element type of `slopes` for this dtype is not
/// stated here — confirm. Returns a crate status code (`0` = success).
///
/// # Safety
/// `scores`, `slopes` and `y` must be valid device pointers;
/// `workspace` / `stream` follow the crate-level contract.
pub fn baracuda_kernels_alibi_bf16_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
scores: *const c_void,
slopes: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `alibi_bf16`. Host-side only.
/// Takes the same dimensions and data pointers as
/// `baracuda_kernels_alibi_bf16_run` (all `*const` here), without the
/// workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_alibi_bf16_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
scores: *const c_void,
slopes: *const c_void,
y: *const c_void,
) -> i32;
/// ALiBi FW, f64. Argument list matches `baracuda_kernels_alibi_f32_run`,
/// whose doc spells out the formula.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_alibi_f64_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
scores: *const c_void,
slopes: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `alibi_f64`. Host-side only.
/// Takes the same dimensions and data pointers as
/// `baracuda_kernels_alibi_f64_run` (all `*const` here), without the
/// workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_alibi_f64_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
scores: *const c_void,
slopes: *const c_void,
y: *const c_void,
) -> i32;
/// ALiBi BW, f32. `da[b, h, i, j] = dy[b, h, i, j]` (pass-through);
/// `dslope[h] = Σ_{b, i, j} dy[b, h, i, j] · (j - i)`. Either `da`
/// or `dslope` may be null to skip; both null is rejected.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_alibi_backward_f32_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
dy: *const c_void,
da: *mut c_void,
dslope: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `alibi_backward_f32`. Host-side only.
/// Takes the same dimensions and data pointers as
/// `baracuda_kernels_alibi_backward_f32_run` (all `*const` here), without
/// the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_alibi_backward_f32_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
dy: *const c_void,
da: *const c_void,
dslope: *const c_void,
) -> i32;
/// ALiBi BW, f16. Argument list matches
/// `baracuda_kernels_alibi_backward_f32_run`, whose doc gives the `da` /
/// `dslope` formulas and the null-pointer rule.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_alibi_backward_f16_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
dy: *const c_void,
da: *mut c_void,
dslope: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `alibi_backward_f16`. Host-side only.
/// Takes the same dimensions and data pointers as
/// `baracuda_kernels_alibi_backward_f16_run` (all `*const` here), without
/// the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_alibi_backward_f16_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
dy: *const c_void,
da: *const c_void,
dslope: *const c_void,
) -> i32;
/// ALiBi BW, bf16. Argument list matches
/// `baracuda_kernels_alibi_backward_f32_run`, whose doc gives the `da` /
/// `dslope` formulas and the null-pointer rule.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_alibi_backward_bf16_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
dy: *const c_void,
da: *mut c_void,
dslope: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `alibi_backward_bf16`. Host-side only.
/// Takes the same dimensions and data pointers as
/// `baracuda_kernels_alibi_backward_bf16_run` (all `*const` here), without
/// the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_alibi_backward_bf16_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
dy: *const c_void,
da: *const c_void,
dslope: *const c_void,
) -> i32;
/// ALiBi BW, f64. Argument list matches
/// `baracuda_kernels_alibi_backward_f32_run`, whose doc gives the `da` /
/// `dslope` formulas and the null-pointer rule.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_alibi_backward_f64_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
dy: *const c_void,
da: *mut c_void,
dslope: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `alibi_backward_f64`. Host-side only.
/// Takes the same dimensions and data pointers as
/// `baracuda_kernels_alibi_backward_f64_run` (all `*const` here), without
/// the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_alibi_backward_f64_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
dy: *const c_void,
da: *const c_void,
dslope: *const c_void,
) -> i32;
// =========================================================================
// Milestone 6.2 — naive Scaled Dot-Product Attention (SDPA). One
// `_run` symbol per dtype per direction; each runs the full
// 3-kernel (FW) / 5-kernel (BW) pipeline internally.
// =========================================================================
/// SDPA FW, f32. Computes `y = softmax(Q·K^T·scale + mask) · V`. The
/// `attn` buffer ([B, H, Q, K]) doubles as the scores intermediate
/// and is overwritten in place with the softmax output (saved for
/// BW). Pass `has_mask = 0` and `mask = nullptr` to skip the mask
/// add. `is_causal = 1` applies an upper-triangular -inf mask
/// inside the scores kernel.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_f32_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
has_mask: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
attn: *mut c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `sdpa_f32`. Host-side only.
/// Takes the same dimensions, `scale`, flags, and data pointers as
/// `baracuda_kernels_sdpa_f32_run` (all `*const` here), without the
/// workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_f32_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
has_mask: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
attn: *const c_void,
y: *const c_void,
) -> i32;
/// SDPA FW, f16 (f32 accumulators). Argument list matches
/// `baracuda_kernels_sdpa_f32_run`, whose doc covers the `attn`, mask, and
/// causal-flag semantics.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_f16_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
has_mask: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
attn: *mut c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `sdpa_f16`. Host-side only.
/// Takes the same dimensions, `scale`, flags, and data pointers as
/// `baracuda_kernels_sdpa_f16_run` (all `*const` here), without the
/// workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_f16_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
has_mask: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
attn: *const c_void,
y: *const c_void,
) -> i32;
/// SDPA FW, bf16 (f32 accumulators). Argument list matches
/// `baracuda_kernels_sdpa_f32_run`, whose doc covers the `attn`, mask, and
/// causal-flag semantics.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_bf16_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
has_mask: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
attn: *mut c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `sdpa_bf16`. Host-side only.
/// Takes the same dimensions, `scale`, flags, and data pointers as
/// `baracuda_kernels_sdpa_bf16_run` (all `*const` here), without the
/// workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_bf16_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
has_mask: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
attn: *const c_void,
y: *const c_void,
) -> i32;
/// SDPA FW, f64. Argument list matches `baracuda_kernels_sdpa_f32_run`,
/// whose doc covers the `attn`, mask, and causal-flag semantics. Note that
/// `scale` is declared `f32` even in this f64 variant.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_f64_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
has_mask: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
attn: *mut c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `sdpa_f64`. Host-side only.
/// Takes the same dimensions, `scale` (`f32`), flags, and data pointers as
/// `baracuda_kernels_sdpa_f64_run` (all `*const` here), without the
/// workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_f64_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
has_mask: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
attn: *const c_void,
y: *const c_void,
) -> i32;
/// SDPA BW, f32. Given the FW-saved `attn` ([B, H, Q, K]), `Q`, `K`,
/// `V`, and upstream `dy`, computes `dQ`, `dK`, `dV`. The
/// `dscores_ws` argument is a caller-allocated [B, H, Q, K] scratch
/// buffer reused as the dattn → dscores intermediate; size matches
/// the FW `attn` tensor.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_backward_f32_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
attn: *const c_void,
dy: *const c_void,
dscores_ws: *mut c_void,
dQ: *mut c_void,
dK: *mut c_void,
dV: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `sdpa_backward_f32`. Host-side only.
/// Takes the same dimensions, `scale`, and data pointers as
/// `baracuda_kernels_sdpa_backward_f32_run` (all `*const` here), without
/// the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_backward_f32_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
attn: *const c_void,
dy: *const c_void,
dscores_ws: *const c_void,
dQ: *const c_void,
dK: *const c_void,
dV: *const c_void,
) -> i32;
/// SDPA BW, f16. Argument list matches
/// `baracuda_kernels_sdpa_backward_f32_run`, whose doc describes the saved
/// `attn` input and the `dscores_ws` scratch buffer.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_backward_f16_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
attn: *const c_void,
dy: *const c_void,
dscores_ws: *mut c_void,
dQ: *mut c_void,
dK: *mut c_void,
dV: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `sdpa_backward_f16`. Host-side only.
/// Takes the same dimensions, `scale`, and data pointers as
/// `baracuda_kernels_sdpa_backward_f16_run` (all `*const` here), without
/// the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_backward_f16_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
attn: *const c_void,
dy: *const c_void,
dscores_ws: *const c_void,
dQ: *const c_void,
dK: *const c_void,
dV: *const c_void,
) -> i32;
/// SDPA BW, bf16. Argument list matches
/// `baracuda_kernels_sdpa_backward_f32_run`, whose doc describes the saved
/// `attn` input and the `dscores_ws` scratch buffer.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_backward_bf16_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
attn: *const c_void,
dy: *const c_void,
dscores_ws: *mut c_void,
dQ: *mut c_void,
dK: *mut c_void,
dV: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `sdpa_backward_bf16`. Host-side only.
/// Takes the same dimensions, `scale`, and data pointers as
/// `baracuda_kernels_sdpa_backward_bf16_run` (all `*const` here), without
/// the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_backward_bf16_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
attn: *const c_void,
dy: *const c_void,
dscores_ws: *const c_void,
dQ: *const c_void,
dK: *const c_void,
dV: *const c_void,
) -> i32;
/// SDPA BW, f64. Argument list matches
/// `baracuda_kernels_sdpa_backward_f32_run`, whose doc describes the saved
/// `attn` input and the `dscores_ws` scratch buffer. Note that `scale` is
/// declared `f32` even in this f64 variant.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_backward_f64_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
attn: *const c_void,
dy: *const c_void,
dscores_ws: *mut c_void,
dQ: *mut c_void,
dK: *mut c_void,
dV: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `sdpa_backward_f64`. Host-side only.
/// Takes the same dimensions, `scale` (`f32`), and data pointers as
/// `baracuda_kernels_sdpa_backward_f64_run` (all `*const` here), without
/// the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_backward_f64_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
attn: *const c_void,
dy: *const c_void,
dscores_ws: *const c_void,
dQ: *const c_void,
dK: *const c_void,
dV: *const c_void,
) -> i32;
// ---- SDPA strided FW + BW × 4 dtypes — Phase 14.4 ----
//
// Per-tensor stride arrays: each is `*const i64` of length 3 (one
// per outer dim: batch, heads, seq). The innermost head_dim axis
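// is implicitly stride=1, enforced by the Rust plan layer. mask
// and attn (FW) / dscores_ws (BW) stay contig (no stride args).
// NOTE(review): the FW signatures in this section do declare a
// `stride_mask` argument, which contradicts "no stride args" for
// mask — confirm which of the two is current.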
//
// GQA broadcast: `stride_k[1]` (head axis) may be zero — kernel
// reads the same K row for every Q-head in the group. Same for V.
// BW does NOT support zero strides on K/V (would require atomic
// reduction over Q-head groups); Rust plan rejects.
//
// attn buffer (FW intermediate / BW saved softmax) stays contig
// [B, H, Q, K]. dscores_ws (BW) is similarly contig.
/// SDPA FW strided, f32. Each `stride_*` argument is a `*const i64` to
/// 3 strides (batch, heads, seq); the innermost head_dim axis is
/// implicitly stride 1. `stride_k[1]` / `stride_v[1]` may be zero for GQA
/// broadcast. `attn` stays contiguous `[B, H, Q, K]`.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_f32_strided_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_mask: *const i64,
stride_y: *const i64,
scale: f32,
is_causal: i32,
has_mask: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
attn: *mut c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `sdpa_f32_strided`. Host-side only.
/// Takes the same dimensions, stride arrays, `scale`, flags, and data
/// pointers as `baracuda_kernels_sdpa_f32_strided_run` (data pointers all
/// `*const` here), without the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_f32_strided_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_mask: *const i64,
stride_y: *const i64,
scale: f32,
is_causal: i32,
has_mask: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
attn: *const c_void,
y: *const c_void,
) -> i32;
/// SDPA FW strided, f16. Argument list matches
/// `baracuda_kernels_sdpa_f32_strided_run`. Each `stride_*` argument is a
/// `*const i64` to 3 strides (batch, heads, seq); the head_dim axis is
/// implicitly stride 1.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_f16_strided_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_mask: *const i64,
stride_y: *const i64,
scale: f32,
is_causal: i32,
has_mask: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
attn: *mut c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `sdpa_f16_strided`. Host-side only.
/// Takes the same dimensions, stride arrays, `scale`, flags, and data
/// pointers as `baracuda_kernels_sdpa_f16_strided_run` (data pointers all
/// `*const` here), without the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_f16_strided_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_mask: *const i64,
stride_y: *const i64,
scale: f32,
is_causal: i32,
has_mask: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
attn: *const c_void,
y: *const c_void,
) -> i32;
/// SDPA FW strided, bf16. Argument list matches
/// `baracuda_kernels_sdpa_f32_strided_run`. Each `stride_*` argument is a
/// `*const i64` to 3 strides (batch, heads, seq); the head_dim axis is
/// implicitly stride 1.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_bf16_strided_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_mask: *const i64,
stride_y: *const i64,
scale: f32,
is_causal: i32,
has_mask: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
attn: *mut c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `sdpa_bf16_strided`. Host-side only.
/// Takes the same dimensions, stride arrays, `scale`, flags, and data
/// pointers as `baracuda_kernels_sdpa_bf16_strided_run` (data pointers all
/// `*const` here), without the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_bf16_strided_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_mask: *const i64,
stride_y: *const i64,
scale: f32,
is_causal: i32,
has_mask: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
attn: *const c_void,
y: *const c_void,
) -> i32;
/// SDPA FW strided, f64. Argument list matches
/// `baracuda_kernels_sdpa_f32_strided_run`. Each `stride_*` argument is a
/// `*const i64` to 3 strides (batch, heads, seq); the head_dim axis is
/// implicitly stride 1. `scale` is declared `f32` even in this f64 variant.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_f64_strided_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_mask: *const i64,
stride_y: *const i64,
scale: f32,
is_causal: i32,
has_mask: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
attn: *mut c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `sdpa_f64_strided`. Host-side only.
/// Takes the same dimensions, stride arrays, `scale` (`f32`), flags, and
/// data pointers as `baracuda_kernels_sdpa_f64_strided_run` (data pointers
/// all `*const` here), without the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_f64_strided_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_mask: *const i64,
stride_y: *const i64,
scale: f32,
is_causal: i32,
has_mask: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
attn: *const c_void,
y: *const c_void,
) -> i32;
/// SDPA BW strided, f32. Each `stride_*` argument is a `*const i64` to
/// 3 strides (batch, heads, seq); the innermost head_dim axis is
/// implicitly stride 1. Zero strides on K/V are not supported in BW (the
/// Rust plan layer rejects them). `attn` and `dscores_ws` stay contiguous
/// `[B, H, Q, K]`.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_backward_f32_strided_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_dy: *const i64,
stride_dq: *const i64,
stride_dk: *const i64,
stride_dv: *const i64,
scale: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
attn: *const c_void,
dy: *const c_void,
dscores_ws: *mut c_void,
dQ: *mut c_void,
dK: *mut c_void,
dV: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `sdpa_backward_f32_strided`. Host-side only.
/// Takes the same dimensions, stride arrays, `scale`, and data pointers as
/// `baracuda_kernels_sdpa_backward_f32_strided_run` (data pointers all
/// `*const` here), without the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_backward_f32_strided_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_dy: *const i64,
stride_dq: *const i64,
stride_dk: *const i64,
stride_dv: *const i64,
scale: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
attn: *const c_void,
dy: *const c_void,
dscores_ws: *const c_void,
dQ: *const c_void,
dK: *const c_void,
dV: *const c_void,
) -> i32;
/// SDPA BW strided, f16. Argument list matches
/// `baracuda_kernels_sdpa_backward_f32_strided_run`. Each `stride_*`
/// argument is a `*const i64` to 3 strides (batch, heads, seq); zero
/// strides on K/V are not supported in BW.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_backward_f16_strided_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_dy: *const i64,
stride_dq: *const i64,
stride_dk: *const i64,
stride_dv: *const i64,
scale: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
attn: *const c_void,
dy: *const c_void,
dscores_ws: *mut c_void,
dQ: *mut c_void,
dK: *mut c_void,
dV: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `sdpa_backward_f16_strided`. Host-side only.
/// Takes the same dimensions, stride arrays, `scale`, and data pointers as
/// `baracuda_kernels_sdpa_backward_f16_strided_run` (data pointers all
/// `*const` here), without the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_backward_f16_strided_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_dy: *const i64,
stride_dq: *const i64,
stride_dk: *const i64,
stride_dv: *const i64,
scale: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
attn: *const c_void,
dy: *const c_void,
dscores_ws: *const c_void,
dQ: *const c_void,
dK: *const c_void,
dV: *const c_void,
) -> i32;
/// SDPA BW strided, bf16. Argument list matches
/// `baracuda_kernels_sdpa_backward_f32_strided_run`. Each `stride_*`
/// argument is a `*const i64` to 3 strides (batch, heads, seq); zero
/// strides on K/V are not supported in BW.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_backward_bf16_strided_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_dy: *const i64,
stride_dq: *const i64,
stride_dk: *const i64,
stride_dv: *const i64,
scale: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
attn: *const c_void,
dy: *const c_void,
dscores_ws: *mut c_void,
dQ: *mut c_void,
dK: *mut c_void,
dV: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `sdpa_backward_bf16_strided`. Host-side only.
/// Takes the same dimensions, stride arrays, `scale`, and data pointers as
/// `baracuda_kernels_sdpa_backward_bf16_strided_run` (data pointers all
/// `*const` here), without the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_backward_bf16_strided_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_dy: *const i64,
stride_dq: *const i64,
stride_dk: *const i64,
stride_dv: *const i64,
scale: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
attn: *const c_void,
dy: *const c_void,
dscores_ws: *const c_void,
dQ: *const c_void,
dK: *const c_void,
dV: *const c_void,
) -> i32;
/// SDPA BW strided, f64. Argument list matches
/// `baracuda_kernels_sdpa_backward_f32_strided_run`. Each `stride_*`
/// argument is a `*const i64` to 3 strides (batch, heads, seq); zero
/// strides on K/V are not supported in BW. `scale` is declared `f32` even
/// in this f64 variant.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_backward_f64_strided_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_dy: *const i64,
stride_dq: *const i64,
stride_dk: *const i64,
stride_dv: *const i64,
scale: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
attn: *const c_void,
dy: *const c_void,
dscores_ws: *mut c_void,
dQ: *mut c_void,
dK: *mut c_void,
dV: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `sdpa_backward_f64_strided`. Host-side only.
/// Takes the same dimensions, stride arrays, `scale` (`f32`), and data
/// pointers as `baracuda_kernels_sdpa_backward_f64_strided_run` (data
/// pointers all `*const` here), without the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_backward_f64_strided_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_dy: *const i64,
stride_dq: *const i64,
stride_dk: *const i64,
stride_dv: *const i64,
scale: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
attn: *const c_void,
dy: *const c_void,
dscores_ws: *const c_void,
dQ: *const c_void,
dK: *const c_void,
dV: *const c_void,
) -> i32;
// =========================================================================
// Milestone 6.5 — KV-cache append (decoder-inference helper). Each
// launcher fires two device-side copy kernels (K + V) on the same
// stream. Pure copy → bit-exact at every dtype.
//
// Inputs:
// k_new : T[B, H, L_new, D_k]
// v_new : T[B, H, L_new, D_v]
// cache_offsets : i64[B] — per-sample insert offset
// Outputs (modified in place):
// k_cache : T[B, H, L_max, D_k]
// v_cache : T[B, H, L_max, D_v]
// Cells where `cache_offsets[b] + l_new >= L_max` are silently skipped.
// =========================================================================
/// KV-cache append, f32. Copies `k_new` (`[B, H, L_new, D_k]`) and `v_new`
/// (`[B, H, L_new, D_v]`) into `k_cache` / `v_cache`
/// (`[B, H, L_max, D_*]`), modified in place, at the per-sample insert
/// offsets in `cache_offsets` (`i64[B]`).
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_kv_cache_append_f32_run(
batch: i32,
heads: i32,
new_len: i32,
max_cache_len: i32,
d_k: i32,
d_v: i32,
k_new: *const c_void,
v_new: *const c_void,
cache_offsets: *const c_void,
k_cache: *mut c_void,
v_cache: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `kv_cache_append_f32`. Host-side only.
/// Takes the same dimensions and data pointers as
/// `baracuda_kernels_kv_cache_append_f32_run` (all `*const` here), without
/// the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_kv_cache_append_f32_can_implement(
batch: i32,
heads: i32,
new_len: i32,
max_cache_len: i32,
d_k: i32,
d_v: i32,
k_new: *const c_void,
v_new: *const c_void,
cache_offsets: *const c_void,
k_cache: *const c_void,
v_cache: *const c_void,
) -> i32;
/// KV-cache append, f16. Argument list matches
/// `baracuda_kernels_kv_cache_append_f32_run`; `cache_offsets` is `i64[B]`
/// (per-sample insert offset) and the caches are modified in place.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_kv_cache_append_f16_run(
batch: i32,
heads: i32,
new_len: i32,
max_cache_len: i32,
d_k: i32,
d_v: i32,
k_new: *const c_void,
v_new: *const c_void,
cache_offsets: *const c_void,
k_cache: *mut c_void,
v_cache: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `kv_cache_append_f16`. Host-side only.
/// Takes the same dimensions and data pointers as
/// `baracuda_kernels_kv_cache_append_f16_run` (all `*const` here), without
/// the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_kv_cache_append_f16_can_implement(
batch: i32,
heads: i32,
new_len: i32,
max_cache_len: i32,
d_k: i32,
d_v: i32,
k_new: *const c_void,
v_new: *const c_void,
cache_offsets: *const c_void,
k_cache: *const c_void,
v_cache: *const c_void,
) -> i32;
/// KV-cache append, bf16. Argument list matches
/// `baracuda_kernels_kv_cache_append_f32_run`; `cache_offsets` is `i64[B]`
/// (per-sample insert offset) and the caches are modified in place.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_kv_cache_append_bf16_run(
batch: i32,
heads: i32,
new_len: i32,
max_cache_len: i32,
d_k: i32,
d_v: i32,
k_new: *const c_void,
v_new: *const c_void,
cache_offsets: *const c_void,
k_cache: *mut c_void,
v_cache: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `kv_cache_append_bf16`. Host-side only.
/// Takes the same dimensions and data pointers as
/// `baracuda_kernels_kv_cache_append_bf16_run` (all `*const` here), without
/// the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_kv_cache_append_bf16_can_implement(
batch: i32,
heads: i32,
new_len: i32,
max_cache_len: i32,
d_k: i32,
d_v: i32,
k_new: *const c_void,
v_new: *const c_void,
cache_offsets: *const c_void,
k_cache: *const c_void,
v_cache: *const c_void,
) -> i32;
/// KV-cache append, f64. Argument list matches
/// `baracuda_kernels_kv_cache_append_f32_run`; `cache_offsets` is `i64[B]`
/// (per-sample insert offset) and the caches are modified in place.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_kv_cache_append_f64_run(
batch: i32,
heads: i32,
new_len: i32,
max_cache_len: i32,
d_k: i32,
d_v: i32,
k_new: *const c_void,
v_new: *const c_void,
cache_offsets: *const c_void,
k_cache: *mut c_void,
v_cache: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `kv_cache_append_f64`. Host-side only.
/// Takes the same dimensions and data pointers as
/// `baracuda_kernels_kv_cache_append_f64_run` (all `*const` here), without
/// the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_kv_cache_append_f64_can_implement(
batch: i32,
heads: i32,
new_len: i32,
max_cache_len: i32,
d_k: i32,
d_v: i32,
k_new: *const c_void,
v_new: *const c_void,
cache_offsets: *const c_void,
k_cache: *const c_void,
v_cache: *const c_void,
) -> i32;
// =========================================================================
// Milestone 6.6 — Flash Attention SDPA (Tri Dao 2022). Tiled fused
// online-softmax kernel that avoids materializing the full
// `[B, H, Q, K]` attention matrix; saves a small `lse: [B, H, Q]`
// log-sum-exp tensor for the BW pass instead. BW is a deterministic
// 3-kernel pipeline (D = rowsum(y ⊙ dy), then dQ per q-block, then
// dK/dV per k-block — each output cell written by exactly one block,
// no atomicAdd). Trailblazer constraints: Br = Bc = 64,
// d_k = d_v ≤ 128.
// =========================================================================
/// Flash SDPA FW, f32. Computes `y = softmax(Q·K^T·scale) · V` via
/// tiled fused online softmax. Optional upper-triangular causal mask
/// (`is_causal = 1`); explicit additive mask is not supported in the
/// trailblazer. Writes `y: [B, H, Q, D_v]` and the saved
/// `lse: [B, H, Q]` log-sum-exp tensor that BW consumes.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_flash_sdpa_f32_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *mut c_void,
lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `flash_sdpa_f32`. Host-side only.
/// Takes the same dimensions, `scale`, `is_causal`, and data pointers as
/// `baracuda_kernels_flash_sdpa_f32_run` (all `*const` here), without the
/// workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_flash_sdpa_f32_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *const c_void,
lse: *const c_void,
) -> i32;
/// Flash SDPA FW, f16 (f32 accumulators). Argument list matches
/// `baracuda_kernels_flash_sdpa_f32_run`, whose doc describes the `y` and
/// `lse` outputs and the causal flag.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_flash_sdpa_f16_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *mut c_void,
lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `flash_sdpa_f16`. Host-side only.
/// Takes the same dimensions, `scale`, `is_causal`, and data pointers as
/// `baracuda_kernels_flash_sdpa_f16_run` (all `*const` here), without the
/// workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_flash_sdpa_f16_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *const c_void,
lse: *const c_void,
) -> i32;
/// Flash SDPA FW, bf16 (f32 accumulators). Argument list matches
/// `baracuda_kernels_flash_sdpa_f32_run`, whose doc describes the `y` and
/// `lse` outputs and the causal flag.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_flash_sdpa_bf16_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *mut c_void,
lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `flash_sdpa_bf16`. Host-side only.
/// Takes the same dimensions, `scale`, `is_causal`, and data pointers as
/// `baracuda_kernels_flash_sdpa_bf16_run` (all `*const` here), without the
/// workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_flash_sdpa_bf16_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *const c_void,
lse: *const c_void,
) -> i32;
/// Flash SDPA FW, f64. Argument list matches
/// `baracuda_kernels_flash_sdpa_f32_run`, whose doc describes the `y` and
/// `lse` outputs and the causal flag. `scale` is declared `f32` even in
/// this f64 variant.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_flash_sdpa_f64_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *mut c_void,
lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `flash_sdpa_f64`. Host-side only.
/// Takes the same dimensions, `scale` (`f32`), `is_causal`, and data
/// pointers as `baracuda_kernels_flash_sdpa_f64_run` (all `*const` here),
/// without the workspace and stream arguments.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_flash_sdpa_f64_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *const c_void,
lse: *const c_void,
) -> i32;
// =========================================================================
// Phase 73 follow-up — FlashDecoding (split-K parallel attention
// decode for seq_q = 1). Closes the perf gap FA2 leaves at the
// decode regime where seq_q is too short to fill FA2's q-tile rows.
// Two-kernel pipeline (split + combine) is implemented in
// `kernels/include/baracuda_flash_decoding.cuh`; this FFI surfaces
// the per-dtype launcher symbols.
//
// Strides are in element units (matching the rest of baracuda's
// strided FFI). GQA is expressed via stride[1] = 0 on K/V.
// =========================================================================
/// FlashDecoding FW, f16 (f32 accumulators). seq_q = 1; split-K
/// over chunks of 256 K-rows each, combined via a second kernel.
///
/// Unlike most `*_run` entry points in this file, the data and workspace
/// pointers come first and the stream argument is named `stream_ptr`.
/// All `*_stride` arguments are in element units.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_flash_decoding_f16_run(
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
batch: i32,
heads: i32,
num_kv_heads: i32,
k_len: i32,
head_dim: i32,
q_b_stride: i64,
q_h_stride: i64,
k_b_stride: i64,
k_h_stride: i64,
k_seq_stride: i64,
v_b_stride: i64,
v_h_stride: i64,
v_seq_stride: i64,
y_b_stride: i64,
y_h_stride: i64,
scale: f32,
stream_ptr: *mut c_void,
) -> i32;
/// Implementability check for `flash_decoding_f16`. Host-side only.
/// Takes only the five integer shape arguments — no pointers, strides, or
/// `scale`.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_flash_decoding_f16_can_implement(
batch: i32,
heads: i32,
num_kv_heads: i32,
k_len: i32,
head_dim: i32,
) -> i32;
/// Workspace requirement for `flash_decoding_f16` in bytes.
/// Computed from `batch`, `heads`, `k_len`, and `head_dim` (no
/// `num_kv_heads`). Returns a byte count as `usize`, not a status code;
/// presumably the minimum `workspace_bytes` to pass to
/// `baracuda_kernels_flash_decoding_f16_run` — TODO confirm.
pub fn baracuda_kernels_flash_decoding_f16_workspace_bytes(
batch: i32,
heads: i32,
k_len: i32,
head_dim: i32,
) -> usize;
/// FlashDecoding FW, bf16 (f32 accumulators). Argument list matches
/// `baracuda_kernels_flash_decoding_f16_run`: data and workspace pointers
/// first, then shape, element-unit strides, `scale`, and `stream_ptr`.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_flash_decoding_bf16_run(
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
batch: i32,
heads: i32,
num_kv_heads: i32,
k_len: i32,
head_dim: i32,
q_b_stride: i64,
q_h_stride: i64,
k_b_stride: i64,
k_h_stride: i64,
k_seq_stride: i64,
v_b_stride: i64,
v_h_stride: i64,
v_seq_stride: i64,
y_b_stride: i64,
y_h_stride: i64,
scale: f32,
stream_ptr: *mut c_void,
) -> i32;
/// Implementability check for `flash_decoding_bf16`. Host-side only.
/// Takes only the five integer shape arguments — no pointers, strides, or
/// `scale`.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_flash_decoding_bf16_can_implement(
batch: i32,
heads: i32,
num_kv_heads: i32,
k_len: i32,
head_dim: i32,
) -> i32;
/// Workspace requirement for `flash_decoding_bf16` in bytes.
/// Computed from `batch`, `heads`, `k_len`, and `head_dim` (no
/// `num_kv_heads`). Returns a byte count as `usize`, not a status code;
/// presumably the minimum `workspace_bytes` to pass to
/// `baracuda_kernels_flash_decoding_bf16_run` — TODO confirm.
pub fn baracuda_kernels_flash_decoding_bf16_workspace_bytes(
batch: i32,
heads: i32,
k_len: i32,
head_dim: i32,
) -> usize;
// =========================================================================
// Phase 51 — arbitrary additive-mask attention FW.
//
// Same online-softmax algorithm as the `flash_sdpa_*_run` family with
// an additional `mask: f32[B, H, Q, K]` additive bias applied to
// S = Q·K^T·scale before the row max/softmax. Mask is **always f32**
// regardless of element dtype — additive-bias precision is decoupled
// from QKV precision and keeps the FFI surface compact. Use
// `-INFINITY` cells in the mask to suppress exactly.
//
// Tier-1 dtype set: {f32, f16, bf16, f64}. FW only (BW deferred to
// Tier 2 — same as the FA2 vendor's deferral).
// =========================================================================
/// Arbitrary additive-mask SDPA FW, f32. `mask` shape `[B, H, Q, K]`
/// f32, applied as an additive bias on the score tile before softmax.
/// `y` and `lse` are the `*mut` data pointers.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_f32_arbmask_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
y: *mut c_void,
lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Arbitrary additive-mask SDPA FW, f16 (f32 accumulators). Argument list
/// matches `baracuda_kernels_sdpa_f32_arbmask_run`; `mask` is always f32
/// regardless of the element dtype.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_f16_arbmask_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
y: *mut c_void,
lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Arbitrary additive-mask SDPA FW, bf16 (f32 accumulators). Argument list
/// matches `baracuda_kernels_sdpa_f32_arbmask_run`; `mask` is always f32
/// regardless of the element dtype.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_bf16_arbmask_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
y: *mut c_void,
lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Arbitrary additive-mask SDPA FW, f64. Argument list matches
/// `baracuda_kernels_sdpa_f32_arbmask_run`; `mask` is always f32
/// regardless of the element dtype, and `scale` is declared `f32` too.
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_sdpa_f64_arbmask_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
mask: *const c_void,
y: *mut c_void,
lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Arbitrary-mask SDPA host-side can-implement, f32.
/// Takes only the six dimensions and `is_causal` — no `scale` and no
/// pointers, unlike the other SDPA `*_can_implement` entry points.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_f32_arbmask_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
is_causal: i32,
) -> i32;
/// Arbitrary-mask SDPA host-side can-implement, f16.
/// Takes only the six dimensions and `is_causal` — no `scale` and no
/// pointers, unlike the other SDPA `*_can_implement` entry points.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_f16_arbmask_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
is_causal: i32,
) -> i32;
/// Arbitrary-mask SDPA host-side can-implement, bf16.
/// Takes only the six dimensions and `is_causal` — no `scale` and no
/// pointers, unlike the other SDPA `*_can_implement` entry points.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_bf16_arbmask_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
is_causal: i32,
) -> i32;
/// Arbitrary-mask SDPA host-side can-implement, f64.
/// Takes only the six dimensions and `is_causal` — no `scale` and no
/// pointers, unlike the other SDPA `*_can_implement` entry points.
///
/// Returns an `i32` status code from the crate-level table (`0` = success).
pub fn baracuda_kernels_sdpa_f64_arbmask_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
is_causal: i32,
) -> i32;
/// Flash SDPA BW, f32. Given the FW-saved `y`, `lse`, plus upstream
/// `dy`, computes `dQ`, `dK`, `dV`. The `d_ws` argument is a
/// caller-allocated `[B, H, Q]` scratch buffer (overwritten with the
/// per-row `D = rowsum(y ⊙ dy)` intermediate; element type matches T).
///
/// Returns an `i32` status code (`0` = success); the crate-level docs list
/// the codes and the pointer, workspace, and stream requirements.
pub fn baracuda_kernels_flash_sdpa_backward_f32_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *const c_void,
lse: *const c_void,
dy: *const c_void,
d_ws: *mut c_void,
dQ: *mut c_void,
dK: *mut c_void,
dV: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `flash_sdpa_backward_f32`. Host-side only.
///
/// Receives the same dimension, `scale`, `is_causal`, and operand-pointer
/// arguments as `baracuda_kernels_flash_sdpa_backward_f32_run`, with
/// every pointer declared `*const` and without `workspace`,
/// `workspace_bytes`, or `stream`. Returns an `i32` status code (see
/// "Status codes" in the crate-level docs).
pub fn baracuda_kernels_flash_sdpa_backward_f32_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *const c_void,
lse: *const c_void,
dy: *const c_void,
d_ws: *const c_void,
dQ: *const c_void,
dK: *const c_void,
dV: *const c_void,
) -> i32;
/// Flash SDPA BW, f16.
///
/// The argument list is identical to
/// `baracuda_kernels_flash_sdpa_backward_f32_run`; see that declaration
/// for the `d_ws` scratch-buffer description. `q`, `k`, `v`, `y`, `lse`,
/// and `dy` are read-only (`*const`); `d_ws`, `dQ`, `dK`, and `dV` are
/// `*mut`. `workspace`, `workspace_bytes`, and `stream` follow the
/// crate-level contract.
///
/// Returns an `i32` status code (see "Status codes" in the crate-level
/// docs).
pub fn baracuda_kernels_flash_sdpa_backward_f16_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *const c_void,
lse: *const c_void,
dy: *const c_void,
d_ws: *mut c_void,
dQ: *mut c_void,
dK: *mut c_void,
dV: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `flash_sdpa_backward_f16`. Host-side only.
///
/// Receives the same dimension, `scale`, `is_causal`, and operand-pointer
/// arguments as `baracuda_kernels_flash_sdpa_backward_f16_run`, with
/// every pointer declared `*const` and without `workspace`,
/// `workspace_bytes`, or `stream`. Returns an `i32` status code (see
/// "Status codes" in the crate-level docs).
pub fn baracuda_kernels_flash_sdpa_backward_f16_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *const c_void,
lse: *const c_void,
dy: *const c_void,
d_ws: *const c_void,
dQ: *const c_void,
dK: *const c_void,
dV: *const c_void,
) -> i32;
/// Flash SDPA BW, bf16.
///
/// The argument list is identical to
/// `baracuda_kernels_flash_sdpa_backward_f32_run`; see that declaration
/// for the `d_ws` scratch-buffer description. `q`, `k`, `v`, `y`, `lse`,
/// and `dy` are read-only (`*const`); `d_ws`, `dQ`, `dK`, and `dV` are
/// `*mut`. `workspace`, `workspace_bytes`, and `stream` follow the
/// crate-level contract.
///
/// Returns an `i32` status code (see "Status codes" in the crate-level
/// docs).
pub fn baracuda_kernels_flash_sdpa_backward_bf16_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *const c_void,
lse: *const c_void,
dy: *const c_void,
d_ws: *mut c_void,
dQ: *mut c_void,
dK: *mut c_void,
dV: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `flash_sdpa_backward_bf16`. Host-side only.
///
/// Receives the same dimension, `scale`, `is_causal`, and operand-pointer
/// arguments as `baracuda_kernels_flash_sdpa_backward_bf16_run`, with
/// every pointer declared `*const` and without `workspace`,
/// `workspace_bytes`, or `stream`. Returns an `i32` status code (see
/// "Status codes" in the crate-level docs).
pub fn baracuda_kernels_flash_sdpa_backward_bf16_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *const c_void,
lse: *const c_void,
dy: *const c_void,
d_ws: *const c_void,
dQ: *const c_void,
dK: *const c_void,
dV: *const c_void,
) -> i32;
/// Flash SDPA BW, f64.
///
/// The argument list is identical to
/// `baracuda_kernels_flash_sdpa_backward_f32_run`; see that declaration
/// for the `d_ws` scratch-buffer description. `q`, `k`, `v`, `y`, `lse`,
/// and `dy` are read-only (`*const`); `d_ws`, `dQ`, `dK`, and `dV` are
/// `*mut`. `workspace`, `workspace_bytes`, and `stream` follow the
/// crate-level contract.
///
/// NOTE(review): `scale` is declared `f32` even though this is the f64
/// variant — confirm this matches the C launcher's signature.
///
/// Returns an `i32` status code (see "Status codes" in the crate-level
/// docs).
pub fn baracuda_kernels_flash_sdpa_backward_f64_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *const c_void,
lse: *const c_void,
dy: *const c_void,
d_ws: *mut c_void,
dQ: *mut c_void,
dK: *mut c_void,
dV: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `flash_sdpa_backward_f64`. Host-side only.
///
/// Receives the same dimension, `scale`, `is_causal`, and operand-pointer
/// arguments as `baracuda_kernels_flash_sdpa_backward_f64_run`, with
/// every pointer declared `*const` and without `workspace`,
/// `workspace_bytes`, or `stream`. Returns an `i32` status code (see
/// "Status codes" in the crate-level docs).
pub fn baracuda_kernels_flash_sdpa_backward_f64_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *const c_void,
lse: *const c_void,
dy: *const c_void,
d_ws: *const c_void,
dQ: *const c_void,
dK: *const c_void,
dV: *const c_void,
) -> i32;
}
// ============================================================================
// Phase 10 Milestone 10.3 — Flash Attention SDPA FW, sm_89 (Ada) sibling
// ============================================================================
//
// Same FW signature as the sm_80 baseline (`baracuda_kernels_flash_sdpa_*_run`)
// — the sm_89 variant is purely a data-movement optimization (`cp.async`
// double-buffered K/V loads + 256-thread block). f16 + bf16 only; f32 /
// f64 stay on the sm_80 baseline. BW is shared (the existing sm_80 BW
// kernels run forward-compat on Ada — there's no Ada-specific BW
// optimization in this milestone).
// Foreign declarations for the sm_89 Flash SDPA forward kernels. Only
// compiled when the `sm89` cargo feature is enabled.
#[cfg(feature = "sm89")]
unsafe extern "C" {
/// Flash SDPA FW, f16 (f32 accumulators), sm_89 specialization with
/// `cp.async` K/V double-buffer.
///
/// `q`, `k`, and `v` are read-only operands; `y` and `lse` are `*mut`
/// outputs. `workspace`, `workspace_bytes`, and `stream` follow the
/// crate-level contract. Returns an `i32` status code (see "Status
/// codes" in the crate-level docs).
pub fn baracuda_kernels_flash_sdpa_sm89_f16_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *mut c_void,
lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `flash_sdpa_sm89_f16`. Host-side only.
///
/// Same leading arguments as the `_run` symbol with all pointers
/// declared `*const`; takes no `workspace`, `workspace_bytes`, or
/// `stream`. Returns an `i32` status code.
pub fn baracuda_kernels_flash_sdpa_sm89_f16_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *const c_void,
lse: *const c_void,
) -> i32;
/// Flash SDPA FW, bf16 (f32 accumulators), sm_89 specialization with
/// `cp.async` K/V double-buffer.
///
/// `q`, `k`, and `v` are read-only operands; `y` and `lse` are `*mut`
/// outputs. `workspace`, `workspace_bytes`, and `stream` follow the
/// crate-level contract. Returns an `i32` status code (see "Status
/// codes" in the crate-level docs).
pub fn baracuda_kernels_flash_sdpa_sm89_bf16_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *mut c_void,
lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `flash_sdpa_sm89_bf16`. Host-side only.
///
/// Same leading arguments as the `_run` symbol with all pointers
/// declared `*const`; takes no `workspace`, `workspace_bytes`, or
/// `stream`. Returns an `i32` status code.
pub fn baracuda_kernels_flash_sdpa_sm89_bf16_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *const c_void,
lse: *const c_void,
) -> i32;
// Phase 17.1 — strided FW siblings.
//
// Same shape semantics as the Phase 14.4 naive SDPA strided FFI:
// stride_q / stride_k / stride_v / stride_y are `*const i64` length 3
// (one per outer dim: batch, heads, seq). The innermost head_dim axis
// is implicitly stride=1 (enforced by the Rust plan layer).
//
// GQA broadcast: pass `stride_k[1] == 0` (or `stride_v[1] == 0`) and
// multiple Q-heads in the same kv-head group dereference the same
// K/V row.
//
// `lse` stays contig `[B, H, Q]` (BW path routes through sm_80
// baseline). Mask is not supported on this strided path — masked
// callers must use the non-strided sm_89 plan (when contig) or the
// sm_80 naive-SDPA strided plan.
/// Flash SDPA FW, f16, sm_89 strided sibling.
///
/// `stride_q` / `stride_k` / `stride_v` / `stride_y` each point to three
/// `i64` strides (batch, heads, seq) as described in the Phase 17.1
/// note above. NOTE(review): whether the stride arrays are read on the
/// host or the device is not visible here — confirm against the
/// launcher. Remaining arguments match the non-strided `_run` symbol.
pub fn baracuda_kernels_flash_sdpa_sm89_f16_strided_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_y: *const i64,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *mut c_void,
lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `flash_sdpa_sm89_f16_strided`. Host-side only.
///
/// Same leading arguments (including the four stride arrays) as the
/// strided `_run` symbol with all operand pointers declared `*const`;
/// takes no `workspace`, `workspace_bytes`, or `stream`.
pub fn baracuda_kernels_flash_sdpa_sm89_f16_strided_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_y: *const i64,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *const c_void,
lse: *const c_void,
) -> i32;
/// Flash SDPA FW, bf16, sm_89 strided sibling.
///
/// `stride_q` / `stride_k` / `stride_v` / `stride_y` each point to three
/// `i64` strides (batch, heads, seq) as described in the Phase 17.1
/// note above. NOTE(review): whether the stride arrays are read on the
/// host or the device is not visible here — confirm against the
/// launcher. Remaining arguments match the non-strided `_run` symbol.
pub fn baracuda_kernels_flash_sdpa_sm89_bf16_strided_run(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_y: *const i64,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *mut c_void,
lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `flash_sdpa_sm89_bf16_strided`. Host-side only.
///
/// Same leading arguments (including the four stride arrays) as the
/// strided `_run` symbol with all operand pointers declared `*const`;
/// takes no `workspace`, `workspace_bytes`, or `stream`.
pub fn baracuda_kernels_flash_sdpa_sm89_bf16_strided_can_implement(
batch: i32,
heads: i32,
q_len: i32,
k_len: i32,
d_k: i32,
d_v: i32,
stride_q: *const i64,
stride_k: *const i64,
stride_v: *const i64,
stride_y: *const i64,
scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
y: *const c_void,
lse: *const c_void,
) -> i32;
}
// ============================================================================
// Phase 42 — Dao-AILab FlashAttention v2 (vendored v2.8.3, BSD-3-Clause)
// ============================================================================
//
// Backend-choice fast path that pairs with baracuda's bespoke `FlashSdpaPlan`
// for long-context regimes where FA2's tiling wins. Tier-1 scope:
// * Forward only (BW deferred to Tier 2).
// * head_dim = 128 only (Tier 3 adds 32/64/96/192/256).
// * f16 + bf16.
// * Dense layout — NO varlen (`cu_seqlens_q/k`), NO GQA
// (`num_heads != num_heads_k` rejected with status 3), NO dropout,
// NO ALiBi, NO rotary, NO paged KV cache.
//
// Tensor layout (contiguous row-major, identical to bespoke `FlashSdpaPlan`):
// * Q : `[batch, num_heads, seq_q, head_dim]`
// * K : `[batch, num_heads_k, seq_k, head_dim]`
// * V : `[batch, num_heads_k, seq_k, head_dim]`
// * out:`[batch, num_heads, seq_q, head_dim]`
//
// **`softmax_lse` is always f32** regardless of the element dtype —
// this differs from the bespoke `FlashSdpaPlan` (where lse matches T).
// FA2 internally accumulates softmax in f32 and writes the LSE tensor in
// f32 to preserve range across long sequences. The plan layer adapts.
//
// Symbols are gated behind the `fa2` cargo feature — compiling FA2's
// CUTLASS-heavy templates adds significant nvcc build time. Off by
// default; enable when you want the dispatch heuristic to be able to
// pick FA2.
// Foreign declarations for the Phase 42 FA2 forward entry points. Only
// compiled when the `fa2` cargo feature is enabled.
#[cfg(feature = "fa2")]
unsafe extern "C" {
/// FA2 forward, f16 (f32 LSE). This is the FA2 forward trailblazer
/// — its safety + LSE saved-tensor contract carries over to the
/// bf16 sibling and the v2 forward entry points.
///
/// # LSE saved-tensor contract (Phase 63)
///
/// `softmax_lse` is a load-bearing OUTPUT of this kernel and the
/// load-bearing INPUT to the corresponding
/// [`baracuda_kernels_fa2_sdpa_backward_f16_run`]. Callers
/// implementing differentiable attention (e.g. an autograd
/// framework's `FlashAttn` op) MUST:
///
/// 1. Pre-allocate `softmax_lse` with
/// [`baracuda_kernels_fa2_sdpa_lse_size`]`(batch, num_heads, seq_q)`
/// f32 elements (multiply by 4 for bytes).
/// 2. Pass the buffer as `softmax_lse` here.
/// 3. **Save the buffer** alongside the operand tensors (`q`,
/// `k`, `v`, `out`) for the BW pass — same exact pointer.
/// 4. Pass the saved buffer as `lse` to the BW launcher.
///
/// LSE is always f32 regardless of the operand dtype (f16 / bf16
/// operands both produce f32 LSE). FA2 internally accumulates
/// softmax in f32 to preserve range across long sequences. Reusing
/// the bespoke `FlashSdpaPlan`'s typed-T LSE on the FA2 BW path
/// is INVALID and rejected at the safe-plan layer.
///
/// # Arguments
///
/// `q`, `k`, and `v` are read-only operands; `out` and `softmax_lse`
/// are `*mut` outputs. Tensor layouts are given in the Phase 42 section
/// comment above this `extern` block. `workspace`, `workspace_bytes`,
/// and `stream` follow the crate-level contract. Returns an `i32`
/// status code (see "Status codes" in the crate-level docs).
pub fn baracuda_kernels_fa2_sdpa_f16_run(
batch: i32,
num_heads: i32,
num_heads_k: i32,
seq_q: i32,
seq_k: i32,
head_dim: i32,
softmax_scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
out: *mut c_void,
softmax_lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// FA2 forward, bf16 (f32 LSE).
///
/// Same argument list as `baracuda_kernels_fa2_sdpa_f16_run`; the
/// safety and LSE saved-tensor contract documented there applies here
/// as well. Returns an `i32` status code.
pub fn baracuda_kernels_fa2_sdpa_bf16_run(
batch: i32,
num_heads: i32,
num_heads_k: i32,
seq_q: i32,
seq_k: i32,
head_dim: i32,
softmax_scale: f32,
is_causal: i32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
out: *mut c_void,
softmax_lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// FA2 can-implement check, f16. Returns 0 / 2 / 3 in the same
/// convention as the corresponding `_run` symbol; the safe-plan
/// layer calls this from `FlashSdpaPlan::can_implement` to validate
/// arguments before a launch.
///
/// Takes only the six dimension arguments and `is_causal` — no
/// `softmax_scale`, pointers, workspace, or stream.
pub fn baracuda_kernels_fa2_sdpa_f16_can_implement(
batch: i32,
num_heads: i32,
num_heads_k: i32,
seq_q: i32,
seq_k: i32,
head_dim: i32,
is_causal: i32,
) -> i32;
/// FA2 can-implement check, bf16.
///
/// Same argument list as `baracuda_kernels_fa2_sdpa_f16_can_implement`:
/// the six dimension arguments and `is_causal` only. Returns an `i32`
/// status code.
pub fn baracuda_kernels_fa2_sdpa_bf16_can_implement(
batch: i32,
num_heads: i32,
num_heads_k: i32,
seq_q: i32,
seq_k: i32,
head_dim: i32,
is_causal: i32,
) -> i32;
}
// ============================================================================
// Phase 59a — FA2 FW expansion: head_dim fanout + GQA + ALiBi + sliding
// window + softcap
// ============================================================================
//
// Phase 42 (Tier-1) shipped head_dim=128 only, no GQA, no ALiBi, no
// sliding window, no softcap. Phase 59a closes Fuel's "still needs
// upstream FA2" gap on the **forward** side:
//
// * **head_dim fanout** — extends the existing `..._run` /
// `..._can_implement` symbols to accept any of FA2 v2.8.3's
// supported head_dims: {32, 64, 96, 128, 192, 256}. The launcher
// dispatches via a runtime switch on the `head_dim` param. Upstream
// FA2 v2.8.3 does NOT ship head_dims 160, 224, or 512 — those are
// permanently Tier-3-deferred (no upstream sources to vendor).
// **No new FFI symbols** — the original v1 `..._run` signature was
// already parameterized on head_dim.
// * **GQA** — the original v1 signature already exposed `num_heads_k`
// as a distinct parameter; Phase 42's launcher rejected
// `num_heads_k != num_heads` outright. Phase 59a's launcher accepts
// any `num_heads_k` where `num_heads % num_heads_k == 0` (FA2's
// `h_h_k_ratio` mechanism handles the broadcast in-kernel).
// **No new FFI symbols** — backwards-compatible loosening.
// * **ALiBi + sliding window + softcap** — these need NEW input
// params. Exposed as `..._run_v2` companion symbols (one per dtype)
// that take the full Phase 59a feature set:
//
// - `alibi_slopes_ptr` — `f32*` device pointer or null. Shape is
// either `[num_heads]` (per-head broadcast across batches) or
// `[batch, num_heads]` (per-batch-per-head). `alibi_batch_stride`
// selects: 0 for the `[num_heads]` layout, `num_heads` for the
// `[batch, num_heads]` layout. FA2's kernel reads the ALiBi
// slope for the active (b, h) and applies an arithmetic bias
// to the score tile before softmax.
// - `window_size_left` / `window_size_right` — i32 sliding window
// bounds. `-1` disables on that side (matches FA2's convention).
// Setting `is_causal=1` forces `window_size_right=0` regardless
// of caller input (causal == "no right context").
// - `softcap` — f32 tanh-cap value. `0.0` disables (matches FA2's
// convention). When > 0, applied as `scores = softcap * tanh(scores / softcap)`
// before softmax (used by Gemma-2 et al.).
//
// The v1 symbols are kept untouched for callers that don't need the
// new params — they internally route through the same launcher with
// ALiBi/sliding/softcap set to their disabled defaults.
// Foreign declarations for the Phase 59a FA2 forward `_v2` entry points
// plus the dense LSE-size helper. Only compiled when the `fa2` cargo
// feature is enabled.
#[cfg(feature = "fa2")]
unsafe extern "C" {
/// FA2 forward, f16 — Phase 59a extended signature.
///
/// Adds ALiBi / sliding window / softcap on top of the Phase 42
/// `..._run` symbol. Pass `alibi_slopes_ptr=null`, `window_size_left=-1`,
/// `window_size_right=-1`, `softcap=0.0` to behave identically to v1.
///
/// `head_dim` must be in {32, 64, 96, 128, 192, 256}. GQA requires
/// `num_heads % num_heads_k == 0`. `alibi_batch_stride` is 0 for
/// the `[num_heads]` ALiBi layout, `num_heads` for `[batch, num_heads]`.
///
/// `alibi_slopes_ptr` is typed `*const c_void` here; the Phase 59a
/// section comment above this block describes it as an `f32*` device
/// pointer or null. The five new parameters sit between `is_causal`
/// and `q`; the remaining arguments keep the v1 order. Returns an
/// `i32` status code.
pub fn baracuda_kernels_fa2_sdpa_f16_run_v2(
batch: i32,
num_heads: i32,
num_heads_k: i32,
seq_q: i32,
seq_k: i32,
head_dim: i32,
softmax_scale: f32,
is_causal: i32,
// Phase 59a additions
alibi_slopes_ptr: *const c_void,
alibi_batch_stride: i32,
window_size_left: i32,
window_size_right: i32,
softcap: f32,
// Original args follow
q: *const c_void,
k: *const c_void,
v: *const c_void,
out: *mut c_void,
softmax_lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// FA2 forward, bf16 — Phase 59a extended signature. See
/// `..._f16_run_v2`.
///
/// The argument list is identical to the f16 `_run_v2` symbol,
/// including the position of the five Phase 59a parameters.
pub fn baracuda_kernels_fa2_sdpa_bf16_run_v2(
batch: i32,
num_heads: i32,
num_heads_k: i32,
seq_q: i32,
seq_k: i32,
head_dim: i32,
softmax_scale: f32,
is_causal: i32,
alibi_slopes_ptr: *const c_void,
alibi_batch_stride: i32,
window_size_left: i32,
window_size_right: i32,
softcap: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
out: *mut c_void,
softmax_lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// FA2 v2 can-implement check, f16. Validates the Phase 59a-extended
/// head_dim set + GQA divisibility + ALiBi/sliding/softcap interactions.
///
/// Takes the dimension arguments, `is_causal`, both window bounds, and
/// `softcap`; it does not receive `alibi_slopes_ptr` or
/// `alibi_batch_stride`. Returns an `i32` status code.
pub fn baracuda_kernels_fa2_sdpa_f16_can_implement_v2(
batch: i32,
num_heads: i32,
num_heads_k: i32,
seq_q: i32,
seq_k: i32,
head_dim: i32,
is_causal: i32,
window_size_left: i32,
window_size_right: i32,
softcap: f32,
) -> i32;
/// FA2 v2 can-implement check, bf16. See `..._f16_can_implement_v2`.
///
/// Same argument list as the f16 check.
pub fn baracuda_kernels_fa2_sdpa_bf16_can_implement_v2(
batch: i32,
num_heads: i32,
num_heads_k: i32,
seq_q: i32,
seq_k: i32,
head_dim: i32,
is_causal: i32,
window_size_left: i32,
window_size_right: i32,
softcap: f32,
) -> i32;
/// Dense FA2 forward LSE size in **f32 elements**:
/// `batch * num_heads * seq_q`. Caller multiplies by 4 for bytes.
/// **Phase 63** — sibling of [`baracuda_kernels_fa2_sdpa_varlen_lse_size`].
///
/// The LSE buffer is the load-bearing saved-tensor input for the BW
/// pass (see `baracuda_kernels_fa2_sdpa_backward_<dt>_run`'s `lse`
/// parameter). Pre-allocate via this helper, pass the same buffer
/// to FW (`softmax_lse` arg) and BW (`lse` arg). LSE is always
/// f32 regardless of operand dtype — see the "LSE saved-tensor
/// contract" section on `baracuda_kernels_fa2_sdpa_f16_run`.
///
/// Returns `0` if any dimension is non-positive (a launch with such
/// dimensions would be a no-op, and the safe-plan layer rejects it
/// up-front).
pub fn baracuda_kernels_fa2_sdpa_lse_size(
batch: i32,
num_heads: i32,
seq_q: i32,
) -> usize;
}
// ============================================================================
// Phase 43 — AndreSlavescu/mHC.cu HyperConnection family (vendored, MIT)
// ============================================================================
//
// Manifold-Constrained Hyper-Connections — a learned residual-stream
// mixing op from DeepSeek-AI's mHC paper, with the unofficial CUDA
// implementation by Andre Slavescu vendored under `vendor/mhc/`.
//
// Tier-1 scope:
// * Static-H forward only (dynamic-H FW + BW deferred).
// * bf16 weights / f32 activations only (upstream's `floatX` is
// hardcoded to nv_bfloat16; f16 / f32 paths require additional
// convert kernels and are deferred).
// * (B, n, C) tuple constrained at `create` time — handle is
// dimensioned-once, reused across calls. n <= 32.
//
// Memory contract:
// * Stateful `MHCLayer*` opaque handle returned by `create`; pass
// to `run` and `destroy`. The handle owns ~B*n*C*sizeof(float)
// bytes of GPU scratch — caller pays alloc cost once.
// * Caller-supplied `stream` is patched in per call (the upstream
// `MHCLayer::stream` field is restored after the launch returns).
// * `workspace` argument unused at present — internal scratch lives
// in the handle. Reserved for future revisions that might surface
// the dynamic-H projection workspace through this API.
//
// Layout contract:
// * x_expanded: [B, n, C] f32, row-major contiguous.
// * rmsnorm_weight: [C] bf16.
// * H_pre: [n] f32 (pre-sigmoid logits).
// * H_post: [n] f32 (pre-sigmoid logits; output gets a 2x
// scale baked in by the kernel).
// * H_res: [n, n] f32 (pre-Sinkhorn-Knopp mixing matrix).
// * out: [B, n, C] f32, row-major contiguous.
//
// Symbols gated behind the `mhc` cargo feature.
// Foreign declarations for the mHC static-H layer (create / destroy /
// run / can_implement). Only compiled when the `mhc` cargo feature is
// enabled.
#[cfg(feature = "mhc")]
unsafe extern "C" {
/// Create an mHC static-H layer handle. Allocates internal GPU
/// scratch. Returns nullptr on failure (invalid args or
/// allocation failure).
///
/// `b`, `c`, and `n` fix the (B, n, C) tuple the handle is dimensioned
/// for (see the Phase 43 section comment above this block).
///
/// `sinkhorn_iters` — typically 20. `eps` — typically 1e-5.
/// `n` must be in `1..=32`.
pub fn baracuda_kernels_mhc_layer_static_bf16_create(
b: i32,
c: i32,
n: i32,
sinkhorn_iters: i32,
eps: f32,
) -> *mut c_void;
/// Destroy an mHC layer handle returned from `create`. Safe to
/// pass nullptr.
pub fn baracuda_kernels_mhc_layer_static_bf16_destroy(handle: *mut c_void);
/// Forward static-H launch. See the Phase 43 section comment above
/// this `extern` block for the layout / shape contract.
///
/// Per that contract: `x_expanded` and `out` are `[B, n, C]` f32,
/// `rmsnorm_weight` is `[C]` bf16, `h_pre` / `h_post` are `[n]` f32,
/// and `h_res` is `[n, n]` f32. `handle` is the pointer returned by
/// `create`. `workspace` is currently unused (scratch lives in the
/// handle). Returns an `i32` status code.
///
/// NOTE(review): `workspace_bytes` is declared `u64` here, whereas the
/// other `_run` declarations in this file use `usize` — confirm this
/// matches the C launcher's signature.
pub fn baracuda_kernels_mhc_layer_static_bf16_run(
handle: *mut c_void,
x_expanded: *const c_void,
rmsnorm_weight: *const c_void,
h_pre: *const c_void,
h_post: *const c_void,
h_res: *const c_void,
out: *mut c_void,
b: i32,
c: i32,
n: i32,
workspace: *mut c_void,
workspace_bytes: u64,
stream: *mut c_void,
) -> i32;
/// Pure-host validation. Returns 0 for supported, 2 for
/// invalid_arg, 3 for unsupported.
///
/// Takes only the `b`, `c`, `n` dimensions; no handle is required.
pub fn baracuda_kernels_mhc_layer_static_bf16_can_implement(
b: i32,
c: i32,
n: i32,
) -> i32;
}
// ============================================================================
// Phase 46 — FlashInfer cherry-pick (Apache-2.0)
// ============================================================================
//
// Four families: paged-KV decode + append, paged / ragged-KV prefill
// (Phase 66 Tier 2), sort-free top-K/top-P/min-P sampling, and cascade
// attention LSE merge. All gated behind the `flashinfer` cargo feature.
// See per-launcher .cu files for the per-symbol caller contract.
#[cfg(feature = "flashinfer")]
unsafe extern "C" {
// Paged KV-cache append (decode-time, 1 token per request).
/// Paged KV-cache append for the decode step, f16 variant.
///
/// `k_data` / `v_data` and the `indices` / `indptr` / `last_page_len`
/// arrays are `*mut`; `key` / `value` are read-only. Element types and
/// layouts are not visible here — see the launcher `.cu` file. There is
/// no workspace argument. Returns an `i32` status code (see "Status
/// codes" in the crate-level docs).
pub fn baracuda_kernels_flashinfer_paged_kv_append_decode_f16_run(
batch_size: i32, page_size: i32, num_heads: i32, head_dim: i32,
k_data: *mut c_void, v_data: *mut c_void,
indices: *mut c_void, indptr: *mut c_void, last_page_len: *mut c_void,
key: *const c_void, value: *const c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for the f16 paged KV-cache append. Takes only
/// the four dimension arguments; returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_paged_kv_append_decode_f16_can_implement(
batch_size: i32, page_size: i32, num_heads: i32, head_dim: i32,
) -> i32;
/// Paged KV-cache append for the decode step, bf16 variant.
///
/// `k_data` / `v_data` and the `indices` / `indptr` / `last_page_len`
/// arrays are `*mut`; `key` / `value` are read-only. Element types and
/// layouts are not visible here — see the launcher `.cu` file. There is
/// no workspace argument. Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_paged_kv_append_decode_bf16_run(
batch_size: i32, page_size: i32, num_heads: i32, head_dim: i32,
k_data: *mut c_void, v_data: *mut c_void,
indices: *mut c_void, indptr: *mut c_void, last_page_len: *mut c_void,
key: *const c_void, value: *const c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for the bf16 paged KV-cache append. Takes only
/// the four dimension arguments; returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_paged_kv_append_decode_bf16_can_implement(
batch_size: i32, page_size: i32, num_heads: i32, head_dim: i32,
) -> i32;
/// Paged KV-cache append for the decode step, f32 variant.
///
/// `k_data` / `v_data` and the `indices` / `indptr` / `last_page_len`
/// arrays are `*mut`; `key` / `value` are read-only. Element types and
/// layouts are not visible here — see the launcher `.cu` file. There is
/// no workspace argument. Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_paged_kv_append_decode_f32_run(
batch_size: i32, page_size: i32, num_heads: i32, head_dim: i32,
k_data: *mut c_void, v_data: *mut c_void,
indices: *mut c_void, indptr: *mut c_void, last_page_len: *mut c_void,
key: *const c_void, value: *const c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for the f32 paged KV-cache append. Takes only
/// the four dimension arguments; returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_paged_kv_append_decode_f32_can_implement(
batch_size: i32, page_size: i32, num_heads: i32, head_dim: i32,
) -> i32;
/// Paged KV-cache append implementability check without a dtype suffix.
/// Takes the same four dimension arguments as the typed checks.
/// NOTE(review): presumably a dtype-independent validation — confirm
/// against the launcher. Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_paged_kv_append_decode_can_implement(
batch_size: i32, page_size: i32, num_heads: i32, head_dim: i32,
) -> i32;
// Batched paged-KV decode.
/// Workspace size query for batched paged-KV decode, as a function of
/// `batch_size` only. NOTE(review): presumably the byte count to pass as
/// `workspace_bytes` to the `paged_decode_*_run` entry points — confirm
/// the unit against the launcher.
pub fn baracuda_kernels_flashinfer_paged_decode_workspace_size(batch_size: i32) -> usize;
/// Batched paged-KV decode, f16 variant.
///
/// `q` is read-only; `o` and `lse` are `*mut` (presumably the attention
/// output and log-sum-exp — confirm). The KV buffers and page-table
/// arrays are declared `*mut`. `workspace`, `workspace_bytes`, and
/// `stream` follow the crate-level contract. Element types and layouts
/// are not visible here. Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_paged_decode_f16_run(
batch_size: i32, page_size: i32, head_dim: i32,
num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32,
k_data: *mut c_void, v_data: *mut c_void,
indices: *mut c_void, indptr: *mut c_void, last_page_len: *mut c_void,
q: *const c_void, o: *mut c_void, lse: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for f16 batched paged-KV decode. Takes the
/// five dimension arguments plus `sm_scale`; no pointers, workspace, or
/// stream. Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_paged_decode_f16_can_implement(
batch_size: i32, page_size: i32, head_dim: i32, num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32,
) -> i32;
/// Batched paged-KV decode, bf16 variant.
///
/// `q` is read-only; `o` and `lse` are `*mut` (presumably the attention
/// output and log-sum-exp — confirm). The KV buffers and page-table
/// arrays are declared `*mut`. `workspace`, `workspace_bytes`, and
/// `stream` follow the crate-level contract. Element types and layouts
/// are not visible here. Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_paged_decode_bf16_run(
batch_size: i32, page_size: i32, head_dim: i32,
num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32,
k_data: *mut c_void, v_data: *mut c_void,
indices: *mut c_void, indptr: *mut c_void, last_page_len: *mut c_void,
q: *const c_void, o: *mut c_void, lse: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for bf16 batched paged-KV decode. Takes the
/// five dimension arguments plus `sm_scale`; no pointers, workspace, or
/// stream. Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_paged_decode_bf16_can_implement(
batch_size: i32, page_size: i32, head_dim: i32, num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32,
) -> i32;
/// Batched paged-KV decode, f32 variant.
///
/// `q` is read-only; `o` and `lse` are `*mut` (presumably the attention
/// output and log-sum-exp — confirm). The KV buffers and page-table
/// arrays are declared `*mut`. `workspace`, `workspace_bytes`, and
/// `stream` follow the crate-level contract. Element types and layouts
/// are not visible here. Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_paged_decode_f32_run(
batch_size: i32, page_size: i32, head_dim: i32,
num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32,
k_data: *mut c_void, v_data: *mut c_void,
indices: *mut c_void, indptr: *mut c_void, last_page_len: *mut c_void,
q: *const c_void, o: *mut c_void, lse: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for f32 batched paged-KV decode. Takes the
/// five dimension arguments plus `sm_scale`; no pointers, workspace, or
/// stream. Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_paged_decode_f32_can_implement(
batch_size: i32, page_size: i32, head_dim: i32, num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32,
) -> i32;
/// Batched paged-KV decode implementability check without a dtype
/// suffix. Takes the five dimension arguments only — unlike the typed
/// checks it has no `sm_scale` parameter. NOTE(review): presumably a
/// dtype-independent validation — confirm against the launcher.
pub fn baracuda_kernels_flashinfer_paged_decode_can_implement(
batch_size: i32, page_size: i32, head_dim: i32,
num_qo_heads: i32, num_kv_heads: i32,
) -> i32;
// Batched paged-KV prefill (Phase 66 Tier 2). Ragged q via q_indptr;
// f16/bf16; causal flag; `enable_split` opts into KV-split parallelism.
// Workspace allocated internally (synchronous).
/// Batched paged-KV prefill, f16 variant.
///
/// `q` is read-only and indexed through `q_indptr`; `o` and `lse` are
/// `*mut`. `causal` and `enable_split` are `i32` flags. There are no
/// `workspace` / `workspace_bytes` parameters (workspace is allocated
/// internally, per the note above). Element types and layouts are not
/// visible here. Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_paged_prefill_f16_run(
batch_size: i32, total_num_rows: i32, page_size: i32, head_dim: i32,
num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32, causal: i32, enable_split: i32,
k_data: *mut c_void, v_data: *mut c_void,
kv_indices: *mut c_void, kv_indptr: *mut c_void, last_page_len: *mut c_void,
q: *const c_void, q_indptr: *mut c_void, o: *mut c_void, lse: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Implementability check for f16 paged-KV prefill. Takes the scalar
/// arguments of the `_run` symbol (dimensions, `sm_scale`, `causal`,
/// `enable_split`) and no pointers. Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_paged_prefill_f16_can_implement(
batch_size: i32, total_num_rows: i32, page_size: i32, head_dim: i32, num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32, causal: i32, enable_split: i32,
) -> i32;
/// Batched paged-KV prefill, bf16 variant.
///
/// `q` is read-only and indexed through `q_indptr`; `o` and `lse` are
/// `*mut`. `causal` and `enable_split` are `i32` flags. There are no
/// `workspace` / `workspace_bytes` parameters. Element types and
/// layouts are not visible here. Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_paged_prefill_bf16_run(
batch_size: i32, total_num_rows: i32, page_size: i32, head_dim: i32,
num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32, causal: i32, enable_split: i32,
k_data: *mut c_void, v_data: *mut c_void,
kv_indices: *mut c_void, kv_indptr: *mut c_void, last_page_len: *mut c_void,
q: *const c_void, q_indptr: *mut c_void, o: *mut c_void, lse: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Implementability check for bf16 paged-KV prefill. Takes the scalar
/// arguments of the `_run` symbol (dimensions, `sm_scale`, `causal`,
/// `enable_split`) and no pointers. Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_paged_prefill_bf16_can_implement(
batch_size: i32, total_num_rows: i32, page_size: i32, head_dim: i32, num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32, causal: i32, enable_split: i32,
) -> i32;
/// Paged-KV prefill implementability check without a dtype suffix. Takes
/// the six dimension arguments only — no `sm_scale`, `causal`, or
/// `enable_split`. NOTE(review): presumably a dtype-independent
/// validation — confirm against the launcher.
pub fn baracuda_kernels_flashinfer_paged_prefill_can_implement(
batch_size: i32, total_num_rows: i32, page_size: i32, head_dim: i32,
num_qo_heads: i32, num_kv_heads: i32,
) -> i32;
// Batched RAGGED-KV prefill (Phase 66 Tier 2). K/V contiguous via
// kv_indptr (no page table). f16/bf16; causal + enable_split flags.
/// Batched ragged-KV prefill, f16 variant.
///
/// `k_data`, `v_data`, and `q` are read-only; `kv_indptr` and `q_indptr`
/// are declared `*mut`; `o` and `lse` are `*mut`. There are no workspace
/// parameters. Element types and layouts are not visible here. Returns
/// an `i32` status code.
pub fn baracuda_kernels_flashinfer_ragged_prefill_f16_run(
batch_size: i32, total_num_rows: i32, total_kv_rows: i32, head_dim: i32,
num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32, causal: i32, enable_split: i32,
k_data: *const c_void, v_data: *const c_void,
kv_indptr: *mut c_void, q: *const c_void, q_indptr: *mut c_void,
o: *mut c_void, lse: *mut c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for f16 ragged-KV prefill. Takes the dimension
/// arguments, `sm_scale`, and `causal`; no pointers.
/// NOTE(review): unlike the typed paged-prefill checks, this one has no
/// `enable_split` parameter — confirm against the C launcher.
pub fn baracuda_kernels_flashinfer_ragged_prefill_f16_can_implement(
batch_size: i32, total_num_rows: i32, total_kv_rows: i32, head_dim: i32, num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32, causal: i32,
) -> i32;
/// Batched ragged-KV prefill, bf16 variant.
///
/// `k_data`, `v_data`, and `q` are read-only; `kv_indptr` and `q_indptr`
/// are declared `*mut`; `o` and `lse` are `*mut`. There are no workspace
/// parameters. Element types and layouts are not visible here. Returns
/// an `i32` status code.
pub fn baracuda_kernels_flashinfer_ragged_prefill_bf16_run(
batch_size: i32, total_num_rows: i32, total_kv_rows: i32, head_dim: i32,
num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32, causal: i32, enable_split: i32,
k_data: *const c_void, v_data: *const c_void,
kv_indptr: *mut c_void, q: *const c_void, q_indptr: *mut c_void,
o: *mut c_void, lse: *mut c_void, stream: *mut c_void,
) -> i32;
/// Implementability check for bf16 ragged-KV prefill. Takes the
/// dimension arguments, `sm_scale`, and `causal`; no pointers.
/// NOTE(review): unlike the typed paged-prefill checks, this one has no
/// `enable_split` parameter — confirm against the C launcher.
pub fn baracuda_kernels_flashinfer_ragged_prefill_bf16_can_implement(
batch_size: i32, total_num_rows: i32, total_kv_rows: i32, head_dim: i32, num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32, causal: i32,
) -> i32;
/// Ragged-KV prefill implementability check without a dtype suffix.
/// Takes the six dimension arguments only — no `sm_scale` or `causal`.
/// NOTE(review): presumably a dtype-independent validation — confirm
/// against the launcher.
pub fn baracuda_kernels_flashinfer_ragged_prefill_can_implement(
batch_size: i32, total_num_rows: i32, total_kv_rows: i32, head_dim: i32,
num_qo_heads: i32, num_kv_heads: i32,
) -> i32;
// Cascade: in-place pairwise LSE merge.
/// In-place pairwise cascade merge, f16 variant.
///
/// `v` and `s` are `*mut` (the state merged into); `v_other` and
/// `s_other` are read-only. Element types and layouts are not visible
/// here — see the launcher `.cu` file. No workspace argument. Returns
/// an `i32` status code.
pub fn baracuda_kernels_flashinfer_merge_state_in_place_f16_run(
seq_len: i32, num_heads: i32, head_dim: i32,
v: *mut c_void, s: *mut c_void,
v_other: *const c_void, s_other: *const c_void,
stream: *mut c_void,
) -> i32;
/// Implementability check for the f16 in-place pairwise merge. Takes
/// only the three dimension arguments; returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_merge_state_in_place_f16_can_implement(
seq_len: i32, num_heads: i32, head_dim: i32,
) -> i32;
/// In-place pairwise cascade merge, bf16 variant.
///
/// `v` and `s` are `*mut` (the state merged into); `v_other` and
/// `s_other` are read-only. Element types and layouts are not visible
/// here. No workspace argument. Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_merge_state_in_place_bf16_run(
seq_len: i32, num_heads: i32, head_dim: i32,
v: *mut c_void, s: *mut c_void,
v_other: *const c_void, s_other: *const c_void,
stream: *mut c_void,
) -> i32;
/// Implementability check for the bf16 in-place pairwise merge. Takes
/// only the three dimension arguments; returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_merge_state_in_place_bf16_can_implement(
seq_len: i32, num_heads: i32, head_dim: i32,
) -> i32;
/// In-place pairwise cascade merge, f32 variant.
///
/// `v` and `s` are `*mut` (the state merged into); `v_other` and
/// `s_other` are read-only. Element types and layouts are not visible
/// here. No workspace argument. Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_merge_state_in_place_f32_run(
seq_len: i32, num_heads: i32, head_dim: i32,
v: *mut c_void, s: *mut c_void,
v_other: *const c_void, s_other: *const c_void,
stream: *mut c_void,
) -> i32;
/// Implementability check for the f32 in-place pairwise merge. Takes
/// only the three dimension arguments; returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_merge_state_in_place_f32_can_implement(
seq_len: i32, num_heads: i32, head_dim: i32,
) -> i32;
/// In-place pairwise merge implementability check without a dtype
/// suffix. Same three dimension arguments as the typed checks.
/// NOTE(review): presumably a dtype-independent validation — confirm.
pub fn baracuda_kernels_flashinfer_merge_state_in_place_can_implement(
seq_len: i32, num_heads: i32, head_dim: i32,
) -> i32;
// Cascade: many-way merge.
/// Many-way cascade merge, f16 variant.
///
/// `v` and `s` are read-only inputs; `v_merged` and `s_merged` are
/// `*mut` outputs. `num_index_sets` is passed alongside the three
/// dimension arguments (presumably the number of states merged per
/// position — confirm). Element types and layouts are not visible
/// here. Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_merge_states_f16_run(
num_index_sets: i32, seq_len: i32, num_heads: i32, head_dim: i32,
v: *const c_void, s: *const c_void,
v_merged: *mut c_void, s_merged: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Implementability check for the f16 many-way merge. Takes only the
/// four integer arguments; returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_merge_states_f16_can_implement(
num_index_sets: i32, seq_len: i32, num_heads: i32, head_dim: i32,
) -> i32;
/// Many-way cascade merge, bf16 variant.
///
/// `v` and `s` are read-only inputs; `v_merged` and `s_merged` are
/// `*mut` outputs. Element types and layouts are not visible here.
/// Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_merge_states_bf16_run(
num_index_sets: i32, seq_len: i32, num_heads: i32, head_dim: i32,
v: *const c_void, s: *const c_void,
v_merged: *mut c_void, s_merged: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Implementability check for the bf16 many-way merge. Takes only the
/// four integer arguments; returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_merge_states_bf16_can_implement(
num_index_sets: i32, seq_len: i32, num_heads: i32, head_dim: i32,
) -> i32;
/// Many-way cascade merge, f32 variant.
///
/// `v` and `s` are read-only inputs; `v_merged` and `s_merged` are
/// `*mut` outputs. Element types and layouts are not visible here.
/// Returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_merge_states_f32_run(
num_index_sets: i32, seq_len: i32, num_heads: i32, head_dim: i32,
v: *const c_void, s: *const c_void,
v_merged: *mut c_void, s_merged: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Implementability check for the f32 many-way merge. Takes only the
/// four integer arguments; returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_merge_states_f32_can_implement(
num_index_sets: i32, seq_len: i32, num_heads: i32, head_dim: i32,
) -> i32;
/// Many-way merge implementability check without a dtype suffix. Same
/// four integer arguments as the typed checks.
/// NOTE(review): presumably a dtype-independent validation — confirm.
pub fn baracuda_kernels_flashinfer_merge_states_can_implement(
num_index_sets: i32, seq_len: i32, num_heads: i32, head_dim: i32,
) -> i32;
// Sort-free sampling: top-K only.
/// Sort-free top-K sampling, f32 variant, with a scalar `top_k_val`.
///
/// `probs` is read-only; `output` and `valid` are `*mut`. `deterministic`,
/// `seed_val`, and `offset_val` are forwarded to the sampler (presumably
/// a determinism flag plus RNG seed and offset — confirm). Element types
/// and shapes of the buffers are not visible here. Returns an `i32`
/// status code.
pub fn baracuda_kernels_flashinfer_top_k_sampling_f32_run(
batch: i32, vocab: i32, top_k_val: i32,
deterministic: i32, seed_val: u64, offset_val: u64,
probs: *const c_void, output: *mut c_void, valid: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Implementability check for scalar top-K sampling. Takes `batch`,
/// `vocab`, and `top_k_val`; returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_top_k_sampling_f32_can_implement(
batch: i32, vocab: i32, top_k_val: i32,
) -> i32;
// Sort-free sampling: top-P only.
/// Sort-free top-P sampling, f32 variant, with a scalar `top_p_val`.
///
/// `probs` is read-only; `output` and `valid` are `*mut`. `deterministic`,
/// `seed_val`, and `offset_val` are forwarded to the sampler (presumably
/// a determinism flag plus RNG seed and offset — confirm). Element types
/// and shapes of the buffers are not visible here. Returns an `i32`
/// status code.
pub fn baracuda_kernels_flashinfer_top_p_sampling_f32_run(
batch: i32, vocab: i32, top_p_val: f32,
deterministic: i32, seed_val: u64, offset_val: u64,
probs: *const c_void, output: *mut c_void, valid: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Implementability check for scalar top-P sampling. Takes `batch`,
/// `vocab`, and `top_p_val`; returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_top_p_sampling_f32_can_implement(
batch: i32, vocab: i32, top_p_val: f32,
) -> i32;
// Sort-free sampling: min-P only.
/// Sort-free min-P sampling, f32 variant, with a scalar `min_p_val`.
///
/// `probs` is read-only; `output` and `valid` are `*mut`. `deterministic`,
/// `seed_val`, and `offset_val` are forwarded to the sampler (presumably
/// a determinism flag plus RNG seed and offset — confirm). Element types
/// and shapes of the buffers are not visible here. Returns an `i32`
/// status code.
pub fn baracuda_kernels_flashinfer_min_p_sampling_f32_run(
batch: i32, vocab: i32, min_p_val: f32,
deterministic: i32, seed_val: u64, offset_val: u64,
probs: *const c_void, output: *mut c_void, valid: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Implementability check for scalar min-P sampling. Takes `batch`,
/// `vocab`, and `min_p_val`; returns an `i32` status code.
pub fn baracuda_kernels_flashinfer_min_p_sampling_f32_can_implement(
batch: i32, vocab: i32, min_p_val: f32,
) -> i32;
// Sort-free sampling: combined top-K + top-P.
/// Launches the combined f32 top-K + top-P sampler on `stream`, with a
/// scalar `top_k_val` and a scalar `top_p_val` for the whole batch.
/// `probs` is the read-only input; `output` and `valid` are output
/// buffers. `deterministic`, `seed_val` and `offset_val` are passed to
/// the kernel by value. Returns an `i32` status from the crate-level
/// status-code table (`0` = success).
///
/// NOTE(review): presumably `probs` is `[batch, vocab]` f32, `output`
/// holds one sampled token id per row and `valid` a per-row success
/// flag; the order in which the two filters are applied is not visible
/// here — confirm against the shim.
pub fn baracuda_kernels_flashinfer_top_k_top_p_sampling_f32_run(
batch: i32, vocab: i32, top_k_val: i32, top_p_val: f32,
deterministic: i32, seed_val: u64, offset_val: u64,
probs: *const c_void, output: *mut c_void, valid: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Probe for the combined scalar sampler: takes only `batch`, `vocab`,
/// `top_k_val` and `top_p_val` and returns an `i32` status
/// (`0` = success).
pub fn baracuda_kernels_flashinfer_top_k_top_p_sampling_f32_can_implement(
batch: i32, vocab: i32, top_k_val: i32, top_p_val: f32,
) -> i32;
// Per-row sampler parameter arrays (Phase 66 Tier 2). The threshold
// is a device array `[batch]` instead of a scalar. `top_k_arr` is i32
// (converted to float internally for the standalone Top-K sampler).
/// Per-row variant of the f32 top-K sampler: `top_k_arr` is a read-only
/// device array of `batch` `i32` thresholds, one per row, replacing the
/// scalar `top_k_val`. `probs` is the read-only input; `output` and
/// `valid` are output buffers. Returns an `i32` status from the
/// crate-level status-code table (`0` = success).
///
/// NOTE(review): presumably `probs` is `[batch, vocab]` f32 with one
/// sampled token id and one validity flag written per row — confirm
/// against the shim.
pub fn baracuda_kernels_flashinfer_top_k_sampling_f32_arr_run(
batch: i32, vocab: i32, top_k_arr: *const c_void,
deterministic: i32, seed_val: u64, offset_val: u64,
probs: *const c_void, output: *mut c_void, valid: *mut c_void, stream: *mut c_void,
) -> i32;
/// Probe for the per-row top-K sampler. Only `batch` and `vocab` are
/// checked; the threshold array is not an argument here. Returns an
/// `i32` status (`0` = success).
pub fn baracuda_kernels_flashinfer_top_k_sampling_f32_arr_can_implement(
batch: i32, vocab: i32,
) -> i32;
/// Per-row variant of the f32 top-P sampler: `top_p_arr` is a read-only
/// device array of `batch` thresholds, one per row, replacing the scalar
/// `top_p_val`. `probs` is the read-only input; `output` and `valid` are
/// output buffers. Returns an `i32` status from the crate-level
/// status-code table (`0` = success).
///
/// NOTE(review): `top_p_arr` is presumably `f32` (the scalar variant
/// takes an `f32`), and `probs` presumably `[batch, vocab]` f32 —
/// confirm against the shim.
pub fn baracuda_kernels_flashinfer_top_p_sampling_f32_arr_run(
batch: i32, vocab: i32, top_p_arr: *const c_void,
deterministic: i32, seed_val: u64, offset_val: u64,
probs: *const c_void, output: *mut c_void, valid: *mut c_void, stream: *mut c_void,
) -> i32;
/// Probe for the per-row top-P sampler. Only `batch` and `vocab` are
/// checked; the threshold array is not an argument here. Returns an
/// `i32` status (`0` = success).
pub fn baracuda_kernels_flashinfer_top_p_sampling_f32_arr_can_implement(
batch: i32, vocab: i32,
) -> i32;
/// Per-row variant of the f32 min-P sampler: `min_p_arr` is a read-only
/// device array of `batch` thresholds, one per row, replacing the scalar
/// `min_p_val`. `probs` is the read-only input; `output` and `valid` are
/// output buffers. Returns an `i32` status from the crate-level
/// status-code table (`0` = success).
///
/// NOTE(review): `min_p_arr` is presumably `f32` (the scalar variant
/// takes an `f32`), and `probs` presumably `[batch, vocab]` f32 —
/// confirm against the shim.
pub fn baracuda_kernels_flashinfer_min_p_sampling_f32_arr_run(
batch: i32, vocab: i32, min_p_arr: *const c_void,
deterministic: i32, seed_val: u64, offset_val: u64,
probs: *const c_void, output: *mut c_void, valid: *mut c_void, stream: *mut c_void,
) -> i32;
/// Probe for the per-row min-P sampler. Only `batch` and `vocab` are
/// checked; the threshold array is not an argument here. Returns an
/// `i32` status (`0` = success).
pub fn baracuda_kernels_flashinfer_min_p_sampling_f32_arr_can_implement(
batch: i32, vocab: i32,
) -> i32;
/// Per-row variant of the combined f32 top-K + top-P sampler:
/// `top_k_arr` and `top_p_arr` are read-only device arrays of `batch`
/// thresholds each, replacing the scalar `top_k_val` / `top_p_val`.
/// `probs` is the read-only input; `output` and `valid` are output
/// buffers. Returns an `i32` status from the crate-level status-code
/// table (`0` = success).
///
/// NOTE(review): presumably `top_k_arr` is `i32` and `top_p_arr` is
/// `f32`, mirroring the scalar variant's parameter types — confirm
/// against the shim.
pub fn baracuda_kernels_flashinfer_top_k_top_p_sampling_f32_arr_run(
batch: i32, vocab: i32, top_k_arr: *const c_void, top_p_arr: *const c_void,
deterministic: i32, seed_val: u64, offset_val: u64,
probs: *const c_void, output: *mut c_void, valid: *mut c_void, stream: *mut c_void,
) -> i32;
/// Probe for the per-row combined sampler. Only `batch` and `vocab` are
/// checked; the threshold arrays are not arguments here. Returns an
/// `i32` status (`0` = success).
pub fn baracuda_kernels_flashinfer_top_k_top_p_sampling_f32_arr_can_implement(
batch: i32, vocab: i32,
) -> i32;
// Speculative-decode verification (Phase 66 Tier 2).
/// Launches f32 chain speculative sampling on `stream`.
///
/// Read-only inputs: `draft_probs`, `draft_token_ids`, `target_probs`.
/// Output buffers: `output_token_ids`, `output_accepted_token_num`,
/// `output_emitted_draft_token_num`. `deterministic`, `seed_val` and
/// `offset_val` are passed to the kernel by value. Returns an `i32`
/// status from the crate-level status-code table (`0` = success).
///
/// NOTE(review): buffer shapes and element types are not visible here —
/// presumably `draft_probs` is `[batch, num_speculative_tokens, vocab]`
/// and `target_probs` carries one extra position per row, as in
/// FlashInfer's chain speculative sampler; confirm against the shim.
pub fn baracuda_kernels_flashinfer_chain_speculative_sampling_f32_run(
batch: i32, num_speculative_tokens: i32, vocab: i32,
deterministic: i32, seed_val: u64, offset_val: u64,
draft_probs: *const c_void, draft_token_ids: *const c_void, target_probs: *const c_void,
output_token_ids: *mut c_void, output_accepted_token_num: *mut c_void,
output_emitted_draft_token_num: *mut c_void, stream: *mut c_void,
) -> i32;
/// Probe for chain speculative sampling: takes only `batch`,
/// `num_speculative_tokens` and `vocab` and returns an `i32` status
/// (`0` = success).
pub fn baracuda_kernels_flashinfer_chain_speculative_sampling_f32_can_implement(
batch: i32, num_speculative_tokens: i32, vocab: i32,
) -> i32;
// FP8 KV-cache decode (Phase 66 Tier 2). k_data/v_data are fp8
// (e4m3 / e5m2); q/o are f16 or bf16. Same arg layout as the
// homogeneous `paged_decode_*_run` symbols.
/// Paged decode with f16 `q` / `o` and an fp8 e4m3 KV cache (`k_data`,
/// `v_data`). `indices`, `indptr` and `last_page_len` accompany the
/// paged cache; `q` is read-only; `o` and `lse` are output buffers.
/// `workspace` is scratch memory of `workspace_bytes` bytes. Returns an
/// `i32` status from the crate-level status-code table (`4` = workspace
/// too small or null when required).
///
/// NOTE(review): the cache and page-table pointers are declared `*mut`
/// although a decode kernel would presumably only read them; whether
/// `lse` may be null is not visible here — confirm against the shim.
pub fn baracuda_kernels_flashinfer_paged_decode_f16_e4m3_run(
batch_size: i32, page_size: i32, head_dim: i32,
num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32,
k_data: *mut c_void, v_data: *mut c_void,
indices: *mut c_void, indptr: *mut c_void, last_page_len: *mut c_void,
q: *const c_void, o: *mut c_void, lse: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Probe for the f16 / e4m3 paged decode. Takes the six scalar problem
/// parameters of the `_run` entry point (including `sm_scale`) and
/// returns an `i32` status (`0` = success).
pub fn baracuda_kernels_flashinfer_paged_decode_f16_e4m3_can_implement(
batch_size: i32, page_size: i32, head_dim: i32, num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32,
) -> i32;
/// Paged decode with f16 `q` / `o` and an fp8 e5m2 KV cache (`k_data`,
/// `v_data`). `indices`, `indptr` and `last_page_len` accompany the
/// paged cache; `q` is read-only; `o` and `lse` are output buffers.
/// `workspace` is scratch memory of `workspace_bytes` bytes. Returns an
/// `i32` status from the crate-level status-code table (`4` = workspace
/// too small or null when required).
///
/// NOTE(review): the cache and page-table pointers are declared `*mut`
/// although a decode kernel would presumably only read them; whether
/// `lse` may be null is not visible here — confirm against the shim.
pub fn baracuda_kernels_flashinfer_paged_decode_f16_e5m2_run(
batch_size: i32, page_size: i32, head_dim: i32,
num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32,
k_data: *mut c_void, v_data: *mut c_void,
indices: *mut c_void, indptr: *mut c_void, last_page_len: *mut c_void,
q: *const c_void, o: *mut c_void, lse: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Probe for the f16 / e5m2 paged decode. Takes the six scalar problem
/// parameters of the `_run` entry point (including `sm_scale`) and
/// returns an `i32` status (`0` = success).
pub fn baracuda_kernels_flashinfer_paged_decode_f16_e5m2_can_implement(
batch_size: i32, page_size: i32, head_dim: i32, num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32,
) -> i32;
/// Paged decode with bf16 `q` / `o` and an fp8 e4m3 KV cache (`k_data`,
/// `v_data`). `indices`, `indptr` and `last_page_len` accompany the
/// paged cache; `q` is read-only; `o` and `lse` are output buffers.
/// `workspace` is scratch memory of `workspace_bytes` bytes. Returns an
/// `i32` status from the crate-level status-code table (`4` = workspace
/// too small or null when required).
///
/// NOTE(review): the cache and page-table pointers are declared `*mut`
/// although a decode kernel would presumably only read them; whether
/// `lse` may be null is not visible here — confirm against the shim.
pub fn baracuda_kernels_flashinfer_paged_decode_bf16_e4m3_run(
batch_size: i32, page_size: i32, head_dim: i32,
num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32,
k_data: *mut c_void, v_data: *mut c_void,
indices: *mut c_void, indptr: *mut c_void, last_page_len: *mut c_void,
q: *const c_void, o: *mut c_void, lse: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Probe for the bf16 / e4m3 paged decode. Takes the six scalar problem
/// parameters of the `_run` entry point (including `sm_scale`) and
/// returns an `i32` status (`0` = success).
pub fn baracuda_kernels_flashinfer_paged_decode_bf16_e4m3_can_implement(
batch_size: i32, page_size: i32, head_dim: i32, num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32,
) -> i32;
/// Paged decode with bf16 `q` / `o` and an fp8 e5m2 KV cache (`k_data`,
/// `v_data`). `indices`, `indptr` and `last_page_len` accompany the
/// paged cache; `q` is read-only; `o` and `lse` are output buffers.
/// `workspace` is scratch memory of `workspace_bytes` bytes. Returns an
/// `i32` status from the crate-level status-code table (`4` = workspace
/// too small or null when required).
///
/// NOTE(review): the cache and page-table pointers are declared `*mut`
/// although a decode kernel would presumably only read them; whether
/// `lse` may be null is not visible here — confirm against the shim.
pub fn baracuda_kernels_flashinfer_paged_decode_bf16_e5m2_run(
batch_size: i32, page_size: i32, head_dim: i32,
num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32,
k_data: *mut c_void, v_data: *mut c_void,
indices: *mut c_void, indptr: *mut c_void, last_page_len: *mut c_void,
q: *const c_void, o: *mut c_void, lse: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Probe for the bf16 / e5m2 paged decode. Takes the six scalar problem
/// parameters of the `_run` entry point (including `sm_scale`) and
/// returns an `i32` status (`0` = success).
pub fn baracuda_kernels_flashinfer_paged_decode_bf16_e5m2_can_implement(
batch_size: i32, page_size: i32, head_dim: i32, num_qo_heads: i32, num_kv_heads: i32, sm_scale: f32,
) -> i32;
}
// Bespoke token-penalty logit transform (Phase 66 Tier 2). NOT behind the
// `flashinfer` feature — a native baracuda elementwise op.
unsafe extern "C" {
/// Launches the f32 token-penalty logit transform on `stream`.
///
/// `logits` is the only mutable buffer and `counts` is read-only;
/// `rep_penalty`, `freq_penalty` and `pres_penalty` are scalar
/// coefficients passed by value. Returns an `i32` status from the
/// crate-level status-code table (`0` = success).
///
/// NOTE(review): presumably `logits` is `[batch, vocab]` f32 rewritten
/// in place and `counts` holds per-token occurrence counts of matching
/// shape (repetition / frequency / presence penalties) — element types
/// and shapes are not visible here; confirm against the kernel source.
///
/// # Safety
/// `logits` and `counts` must be valid device addresses and `stream` a
/// valid CUDA stream, per the crate-level contract for raw entry points.
pub fn baracuda_kernels_apply_token_penalty_f32_run(
batch: i32, vocab: i32, rep_penalty: f32, freq_penalty: f32, pres_penalty: f32,
logits: *mut c_void, counts: *const c_void, stream: *mut c_void,
) -> i32;
/// Probe for the token-penalty transform: takes only `batch` and `vocab`
/// and returns an `i32` status from the crate-level status-code table
/// (`0` = success).
pub fn baracuda_kernels_apply_token_penalty_f32_can_implement(batch: i32, vocab: i32) -> i32;
}
// ============================================================================
// cuSOLVER — Milestone 6.3 dense linalg
// ============================================================================
//
// Host-API cuSOLVER bindings for the four canonical dense factorizations:
// Cholesky (`potrf`), LU (`getrf`), QR (`geqrf` + `ormqr`), SVD (`gesvd`).
// f32 + f64 only — cuSOLVER's dense API does not expose f16 / bf16 for
// these operations. Cholesky and LU also expose batched variants
// (`*Batched`) that operate on an array of `[N, N]` matrices in a
// single launch; QR / SVD have no batched dense variant on cuSOLVER and
// are 2-D-only at the plan layer.
//
// Linkage: `cargo:rustc-link-lib=dylib=cusolver` + `=cublas` (both
// added in build.rs — cuSOLVER's dense API depends on cuBLAS). On Linux
// these resolve to `libcusolver.so` / `libcublas.so`; on Windows to
// `cusolver64_*.dll` / `cublas64_*.dll` (loaded from `CUDA_PATH\bin`).
//
// All cuSOLVER routines are column-major (matching LAPACK convention).
// The safe-plan layer in `baracuda-kernels` handles the row-major →
// column-major adapter — for symmetric ops (Cholesky) this is a uplo-flip
// (row-major lower-L over storage `S` is bit-identical to column-major
// upper-U over the same `S`); for non-symmetric ops (LU / QR / SVD) the
// plan documents that the input/output is interpreted as the transpose.
/// Opaque cuSOLVER dense handle. Stateful object; the plan layer creates
/// one lazily on first `run` and reuses across launches.
#[allow(non_camel_case_types)]
pub type cusolverDnHandle_t = *mut c_void;
/// Opaque cuBLAS handle. Used by `cublas*geqrfBatched` (which lives in
/// cuBLAS, not cuSOLVER) and any future cuBLAS-routed linalg paths.
#[allow(non_camel_case_types)]
pub type cublasHandle_t = *mut c_void;
/// Opaque cuSOLVER Jacobi-SVD parameter object. Stateful; created
/// once per plan, reused across launches, destroyed on plan drop.
/// Used by `cusolverDn*gesvdjBatched` for the batched-SVD path.
#[allow(non_camel_case_types)]
pub type gesvdjInfo_t = *mut c_void;
/// cuBLAS fill-mode tag re-used by cuSOLVER for triangular factorizations.
/// `CUBLAS_FILL_MODE_LOWER = 0`, `CUBLAS_FILL_MODE_UPPER = 1`.
#[allow(non_camel_case_types)]
pub type cublasFillMode_t = i32;
/// `CUBLAS_FILL_MODE_LOWER` — pass to `potrf` to request the lower-
/// triangular Cholesky factor.
pub const CUBLAS_FILL_MODE_LOWER: i32 = 0;
/// `CUBLAS_FILL_MODE_UPPER` — pass to `potrf` to request the upper-
/// triangular Cholesky factor.
pub const CUBLAS_FILL_MODE_UPPER: i32 = 1;
/// `CUBLAS_OP_N` — no transpose. Used by `ormqr` to control whether to
/// apply `Q` or `Q^T`.
pub const CUBLAS_OP_N: i32 = 0;
/// `CUBLAS_OP_T` — transpose.
pub const CUBLAS_OP_T: i32 = 1;
/// `CUBLAS_OP_C` — conjugate transpose (only meaningful for complex
/// dtypes). Used by `cusolverDn{C,Z}unmqr` to apply `Q^H`.
pub const CUBLAS_OP_C: i32 = 2;
// --- cudaDataType / cublasComputeType_t tags (Phase 47) -----------------
// NOTE: `CUDA_R_16F`, `CUDA_R_32F` and `CUDA_R_64F` are defined further
// down in this cuSOLVER section, next to the `cudaDataType` alias (same
// numeric values as in <library_types.h>). They are kept commented out
// here only to avoid a duplicate-definition error; `CUDA_R_16BF` and the
// `CUBLAS_COMPUTE_*` tags are the definitions unique to this Phase 47
// block.
// pub const CUDA_R_16F: i32 = 2;
// pub const CUDA_R_32F: i32 = 0;
// pub const CUDA_R_64F: i32 = 1;
/// `CUDA_R_16BF` — bfloat16 (real). Storage tag for `__nv_bfloat16`.
pub const CUDA_R_16BF: i32 = 14;
/// `CUBLAS_COMPUTE_32F` — fp32 accumulator.
pub const CUBLAS_COMPUTE_32F: i32 = 68;
/// `CUBLAS_COMPUTE_64F` — fp64 accumulator.
pub const CUBLAS_COMPUTE_64F: i32 = 70;
/// `CUBLAS_GEMM_DEFAULT` — let cuBLAS pick the algorithm.
pub const CUBLAS_GEMM_DEFAULT: i32 = -1;
/// `CUBLAS_SIDE_LEFT` — `Q` is applied from the left in `ormqr`
/// (`C := Q · C` or `C := Q^T · C`).
pub const CUBLAS_SIDE_LEFT: i32 = 0;
/// `CUBLAS_SIDE_RIGHT` — `Q` is applied from the right.
pub const CUBLAS_SIDE_RIGHT: i32 = 1;
/// cuBLAS diag-type tag for triangular solves (`trsm`).
/// `CUBLAS_DIAG_NON_UNIT = 0`, `CUBLAS_DIAG_UNIT = 1`.
#[allow(non_camel_case_types)]
pub type cublasDiagType_t = i32;
/// `CUBLAS_DIAG_NON_UNIT` — `trsm` reads the actual diagonal of `A`.
/// Used by the LstSq QR-fallback path for the back-substitution
/// `R · X = Q^T · B`, where `R`'s diagonal is the meaningful pivots.
pub const CUBLAS_DIAG_NON_UNIT: i32 = 0;
/// `CUBLAS_DIAG_UNIT` — `trsm` treats the diagonal as all-1s
/// (unit-triangular). Not used by the current plan layer; surfaced
/// for completeness.
pub const CUBLAS_DIAG_UNIT: i32 = 1;
/// `CUSOLVER_STATUS_SUCCESS` — the only success code. Any non-zero
/// return from a cuSOLVER routine is mapped to a negative status at the
/// safe-plan layer for distinct error reporting.
pub const CUSOLVER_STATUS_SUCCESS: i32 = 0;
/// cuSOLVER eig-mode enum tag (used by `syevd` / `heevd` / `Xgeev`).
/// `0 = NOVECTOR` (compute eigenvalues only), `1 = VECTOR` (eigenvalues +
/// eigenvectors). Routed through as an `i32` for the legacy syevd /
/// heevd APIs. The `CUSOLVER_EIG_MODE_NOVECTOR` / `_VECTOR` constants
/// live further down (originally introduced for `gesvdjBatched`'s
/// `jobz` argument; reused verbatim here for the eig family).
#[allow(non_camel_case_types)]
pub type cusolverEigMode_t = i32;
/// `cudaDataType` tag used by the 64-bit cuSOLVER APIs (`Xgeev`,
/// `Xgesvd`, …) to identify tensor element types. These constants
/// originate in `<library_types.h>` and are stable across CUDA versions.
#[allow(non_camel_case_types)]
pub type cudaDataType = i32;
/// `CUDA_R_32F` — real `f32`.
pub const CUDA_R_32F: i32 = 0;
/// `CUDA_R_64F` — real `f64`.
pub const CUDA_R_64F: i32 = 1;
/// `CUDA_R_16F` — real `f16`.
pub const CUDA_R_16F: i32 = 2;
/// `CUDA_C_32F` — complex `f32` (interleaved real/imag).
pub const CUDA_C_32F: i32 = 4;
/// `CUDA_C_64F` — complex `f64` (interleaved real/imag).
pub const CUDA_C_64F: i32 = 5;
/// Opaque parameter struct used by the 64-bit cuSOLVER APIs (`Xgeev`,
/// `Xpotrf`, …). The struct holds advanced configuration (algorithm
/// choice, precision modes) — for the trailblazer the plan layer leaves
/// it at defaults. Created via `cusolverDnCreateParams` and destroyed via
/// `cusolverDnDestroyParams`.
#[allow(non_camel_case_types)]
pub type cusolverDnParams_t = *mut c_void;
/// Single-precision complex number laid out exactly like `cuComplex`
/// from `<cuComplex.h>`: two interleaved `f32` values, real part first.
/// The layout is shared with [`crate::cufftComplex`] and with the
/// safe-side `Complex32` from `baracuda-kernels-types`, so a
/// `DeviceBuffer<Complex32>` can be handed to the cuSOLVER complex APIs
/// as a `*mut cuComplex` without a copy.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, PartialEq, Default, Debug)]
#[repr(C)]
pub struct cuComplex {
    /// Real part.
    pub x: f32,
    /// Imaginary part.
    pub y: f32,
}
/// Double-precision complex number laid out exactly like
/// `cuDoubleComplex` from `<cuComplex.h>`: two interleaved `f64` values,
/// real part first. The `f64` counterpart of [`cuComplex`].
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, PartialEq, Default, Debug)]
#[repr(C)]
pub struct cuDoubleComplex {
    /// Real part.
    pub x: f64,
    /// Imaginary part.
    pub y: f64,
}
/// `cuFloatComplex` is the canonical CUDA name for the single-precision
/// complex struct — an alias for [`cuComplex`]. Surfaced so cuSOLVER's
/// complex APIs (`cusolverDn{C,Z}unmqr`, …) can spell their signatures
/// in the same vocabulary as the NVIDIA headers.
#[allow(non_camel_case_types)]
pub type cuFloatComplex = cuComplex;
unsafe extern "C" {
// ----- handle lifecycle ----------------------------------------------
/// `cusolverDnCreate(handle)`. Returns 0 on success.
///
/// # Safety
/// `handle` must point to writable storage for one `cusolverDnHandle_t`.
pub fn cusolverDnCreate(handle: *mut cusolverDnHandle_t) -> i32;
/// `cusolverDnDestroy(handle)`. Returns 0 on success.
///
/// # Safety
/// `handle` must be a valid handle returned by `cusolverDnCreate` that
/// has not been previously destroyed.
pub fn cusolverDnDestroy(handle: cusolverDnHandle_t) -> i32;
/// `cusolverDnSetStream(handle, stream)`. Binds subsequent cuSOLVER
/// calls to the given CUDA stream. Returns 0 on success.
///
/// # Safety
/// `handle` must be a live cuSOLVER handle; `stream` must be a valid
/// CUDA stream in the current context (or null for the default stream).
pub fn cusolverDnSetStream(handle: cusolverDnHandle_t, stream: *mut c_void) -> i32;
// ----- Cholesky: potrf (f32 / f64) -----------------------------------
/// `cusolverDnSpotrf_bufferSize` — query the workspace size needed by
/// `cusolverDnSpotrf`. The value written to `lwork` is an element count,
/// not a byte count: multiply it by `sizeof(float)` before passing it to
/// `cudaMalloc`.
///
/// # Safety
/// `handle` must be live; `a` must be a device pointer to `n*n` `float`
/// cells with leading dimension `lda`; `lwork` must point to writable
/// storage for one `int`.
pub fn cusolverDnSpotrf_bufferSize(
handle: cusolverDnHandle_t,
uplo: cublasFillMode_t,
n: i32,
a: *mut f32,
lda: i32,
lwork: *mut i32,
) -> i32;
/// `cusolverDnSpotrf` — in-place Cholesky factorization (`A := L` or
/// `A := U`, selected by `uplo`). The triangle not selected by `uplo` is
/// left untouched. `dev_info` receives 0 on success, or `k > 0` if the
/// leading minor of order `k` is not positive definite (factorization
/// halted at step `k`).
///
/// # Safety
/// All pointers reference device memory; `workspace` has at least
/// `lwork * sizeof(float)` bytes, where `lwork` is the element count
/// reported by `cusolverDnSpotrf_bufferSize`; `dev_info` references one
/// `int`.
pub fn cusolverDnSpotrf(
handle: cusolverDnHandle_t,
uplo: cublasFillMode_t,
n: i32,
a: *mut f32,
lda: i32,
workspace: *mut f32,
lwork: i32,
dev_info: *mut i32,
) -> i32;
/// `cusolverDnDpotrf_bufferSize`. f64 analogue.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDpotrf_bufferSize(
handle: cusolverDnHandle_t,
uplo: cublasFillMode_t,
n: i32,
a: *mut f64,
lda: i32,
lwork: *mut i32,
) -> i32;
/// `cusolverDnDpotrf`. f64 analogue.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDpotrf(
handle: cusolverDnHandle_t,
uplo: cublasFillMode_t,
n: i32,
a: *mut f64,
lda: i32,
workspace: *mut f64,
lwork: i32,
dev_info: *mut i32,
) -> i32;
// ----- Cholesky batched ----------------------------------------------
/// `cusolverDnSpotrfBatched(handle, uplo, n, Aarray, lda, infoArray,
/// batchSize)`. Each matrix in `Aarray[batch_size]` is factored
/// independently in-place. Returns 0 on success; per-matrix factor
/// info lands in `infoArray[i]`.
///
/// # Safety
/// `Aarray` is a device-resident array of `batch_size` pointers, each
/// pointing to an `n × n` `float` matrix with leading dimension `lda`.
/// `infoArray` is a device-resident `int[batch_size]` written by the
/// kernel. Note: cuSOLVER's batched API does **not** take a workspace
/// argument — the library allocates internally.
pub fn cusolverDnSpotrfBatched(
handle: cusolverDnHandle_t,
uplo: cublasFillMode_t,
n: i32,
a_array: *mut *mut f32,
lda: i32,
info_array: *mut i32,
batch_size: i32,
) -> i32;
/// `cusolverDnDpotrfBatched`. f64 analogue.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDpotrfBatched(
handle: cusolverDnHandle_t,
uplo: cublasFillMode_t,
n: i32,
a_array: *mut *mut f64,
lda: i32,
info_array: *mut i32,
batch_size: i32,
) -> i32;
// ----- LU: getrf (f32 / f64) -----------------------------------------
/// `cusolverDnSgetrf_bufferSize` — query workspace element count.
///
/// # Safety
/// `handle` live; `A` device pointer to `m*n` `float` cells with
/// leading dimension `lda`; `lwork` writable `int`.
pub fn cusolverDnSgetrf_bufferSize(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
a: *mut f32,
lda: i32,
lwork: *mut i32,
) -> i32;
/// `cusolverDnSgetrf` — in-place LU factorization with partial pivoting:
/// `P · A = L · U`. `a` is overwritten with the packed factors (`L`
/// unit-diagonal, stored in the strict lower triangle; `U` in the upper
/// triangle). `ipiv[i]` is the row swap performed at step `i` (1-based
/// per LAPACK convention).
///
/// Unlike `potrf` / `geqrf`, this entry point takes no `lwork` argument;
/// the workspace must simply be large enough.
///
/// # Safety
/// All pointers reference device memory; `workspace` ≥
/// `lwork * sizeof(float)` bytes, where `lwork` is the element count
/// reported by `cusolverDnSgetrf_bufferSize`; `ipiv` ≥
/// `min(m, n) * sizeof(int)` bytes; `dev_info` is one `int`.
pub fn cusolverDnSgetrf(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
a: *mut f32,
lda: i32,
workspace: *mut f32,
ipiv: *mut i32,
dev_info: *mut i32,
) -> i32;
/// `cusolverDnDgetrf_bufferSize`. f64 analogue.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDgetrf_bufferSize(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
a: *mut f64,
lda: i32,
lwork: *mut i32,
) -> i32;
/// `cusolverDnDgetrf`. f64 analogue.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDgetrf(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
a: *mut f64,
lda: i32,
workspace: *mut f64,
ipiv: *mut i32,
dev_info: *mut i32,
) -> i32;
// ----- LU solve: getrs (f32 / f64) -----------------------------------
//
// `getrs` consumes the packed `LU` factors + pivot vector produced
// by `getrf` and solves `op(A) · X = B` in place over `B`. cuSOLVER
// does not expose a `_bufferSize` query for `getrs` — the routine
// is workspace-free.
/// `cusolverDnSgetrs` — solve `op(A) · X = B` using the packed `LU`
/// + pivot produced by `cusolverDnSgetrf`. `B` is overwritten in
/// place with the solution `X`. `trans` selects `op(A)`:
/// `CUBLAS_OP_N` for `A`, `CUBLAS_OP_T` for `A^T`.
///
/// # Safety
/// `handle` live + stream-bound; `A` is the packed `getrf` output
/// `n × n` `float` (lda ≥ n); `ipiv` is the 1-based pivot vector of
/// length `n` returned by `getrf`; `B` is `n × nrhs` `float` (ldb ≥
/// n); `dev_info` is one writable `int`.
pub fn cusolverDnSgetrs(
handle: cusolverDnHandle_t,
trans: i32,
n: i32,
nrhs: i32,
a: *const f32,
lda: i32,
ipiv: *const i32,
b: *mut f32,
ldb: i32,
dev_info: *mut i32,
) -> i32;
/// `cusolverDnDgetrs`. f64 analogue.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDgetrs(
handle: cusolverDnHandle_t,
trans: i32,
n: i32,
nrhs: i32,
a: *const f64,
lda: i32,
ipiv: *const i32,
b: *mut f64,
ldb: i32,
dev_info: *mut i32,
) -> i32;
// ----- QR: geqrf + ormqr (f32 / f64) ---------------------------------
//
// Note on the LU section above: cuSOLVER's dense API does not expose a
// batched LU (`cublasSgetrfBatched` lives in cuBLAS — wiring batched LU
// through cuBLAS is deferred to a future milestone). Batched Cholesky
// stays in cuSOLVER (`*potrfBatched` above).
/// `cusolverDnSgeqrf_bufferSize`.
///
/// # Safety
/// `handle` live; `A` device `m × n` `float` with leading dimension
/// `lda`; `lwork` writable `int`.
pub fn cusolverDnSgeqrf_bufferSize(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
a: *mut f32,
lda: i32,
lwork: *mut i32,
) -> i32;
/// `cusolverDnSgeqrf` — QR factorization in place. `a` is overwritten:
/// the upper triangle holds `R`, and the strict lower triangle together
/// with `tau` holds the Householder reflectors that encode `Q`. To
/// materialize `Q` as a dense matrix, follow with `ormqr` against an
/// identity.
///
/// # Safety
/// All pointers reference device memory; `tau` holds at least
/// `min(m, n)` `float` elements; `workspace` holds at least `lwork`
/// `float` elements (`lwork * sizeof(float)` bytes); `dev_info` is one
/// `int`.
pub fn cusolverDnSgeqrf(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
a: *mut f32,
lda: i32,
tau: *mut f32,
workspace: *mut f32,
lwork: i32,
dev_info: *mut i32,
) -> i32;
/// `cusolverDnDgeqrf_bufferSize`. f64 analogue.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDgeqrf_bufferSize(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
a: *mut f64,
lda: i32,
lwork: *mut i32,
) -> i32;
/// `cusolverDnDgeqrf`. f64 analogue.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDgeqrf(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
a: *mut f64,
lda: i32,
tau: *mut f64,
workspace: *mut f64,
lwork: i32,
dev_info: *mut i32,
) -> i32;
/// `cusolverDnSormqr_bufferSize`. `trans` selects `Q` vs `Q^T`;
/// `side` selects left vs right multiply.
///
/// # Safety
/// `handle` live; `A` / `C` are the `geqrf`-output matrix and the
/// target matrix respectively; `tau` is the Householder scalars from
/// `geqrf`. `lwork` writable `int`.
pub fn cusolverDnSormqr_bufferSize(
handle: cusolverDnHandle_t,
side: i32,
trans: i32,
m: i32,
n: i32,
k: i32,
a: *const f32,
lda: i32,
tau: *const f32,
c: *const f32,
ldc: i32,
lwork: *mut i32,
) -> i32;
/// `cusolverDnSormqr` — apply `Q` (or `Q^T`) from `geqrf` output to
/// a matrix `C` in place. With `C = I` this materializes `Q` as a
/// dense matrix for the "thin" or "full" QR.
///
/// # Safety
/// All pointers reference device memory; `workspace ≥
/// lwork * sizeof(T)`; `dev_info` is one `int`.
pub fn cusolverDnSormqr(
handle: cusolverDnHandle_t,
side: i32,
trans: i32,
m: i32,
n: i32,
k: i32,
a: *const f32,
lda: i32,
tau: *const f32,
c: *mut f32,
ldc: i32,
workspace: *mut f32,
lwork: i32,
dev_info: *mut i32,
) -> i32;
/// `cusolverDnDormqr_bufferSize`. f64 analogue.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDormqr_bufferSize(
handle: cusolverDnHandle_t,
side: i32,
trans: i32,
m: i32,
n: i32,
k: i32,
a: *const f64,
lda: i32,
tau: *const f64,
c: *const f64,
ldc: i32,
lwork: *mut i32,
) -> i32;
/// `cusolverDnDormqr`. f64 analogue.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDormqr(
handle: cusolverDnHandle_t,
side: i32,
trans: i32,
m: i32,
n: i32,
k: i32,
a: *const f64,
lda: i32,
tau: *const f64,
c: *mut f64,
ldc: i32,
workspace: *mut f64,
lwork: i32,
dev_info: *mut i32,
) -> i32;
// ----- QR factorization (complex): geqrf (Complex32 / Complex64) ------
/// `cusolverDnCgeqrf_bufferSize` — workspace query for single-precision
/// complex QR factorization. Mirrors `cusolverDnSgeqrf_bufferSize`.
///
/// # Safety
/// `handle` live; `A` device `m × n` `cuFloatComplex` with leading
/// dimension `lda`; `lwork` writable `int`.
pub fn cusolverDnCgeqrf_bufferSize(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
a: *mut cuFloatComplex,
lda: i32,
lwork: *mut i32,
) -> i32;
/// `cusolverDnCgeqrf` — single-precision complex QR factorization,
/// in place. The packed output uses the same convention as the real
/// variant: strict lower triangle + `tau` encode the Householder
/// reflectors; the upper triangle holds `R`.
///
/// # Safety
/// All pointers reference device memory; `tau ≥ min(m, n)` cells;
/// `workspace ≥ lwork * sizeof(cuFloatComplex)`; `dev_info` one `int`.
pub fn cusolverDnCgeqrf(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
a: *mut cuFloatComplex,
lda: i32,
tau: *mut cuFloatComplex,
workspace: *mut cuFloatComplex,
lwork: i32,
dev_info: *mut i32,
) -> i32;
/// `cusolverDnZgeqrf_bufferSize`. f64-complex analogue of the C variant.
///
/// # Safety
/// Same as the C variant with `cuDoubleComplex` storage.
pub fn cusolverDnZgeqrf_bufferSize(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
a: *mut cuDoubleComplex,
lda: i32,
lwork: *mut i32,
) -> i32;
/// `cusolverDnZgeqrf` — double-precision complex QR factorization.
///
/// # Safety
/// Same as the C variant with `cuDoubleComplex` storage.
pub fn cusolverDnZgeqrf(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
a: *mut cuDoubleComplex,
lda: i32,
tau: *mut cuDoubleComplex,
workspace: *mut cuDoubleComplex,
lwork: i32,
dev_info: *mut i32,
) -> i32;
// ----- Apply Q from QR (complex): unmqr (Complex32 / Complex64) -------
//
// cuSOLVER spells the complex apply-Q routine `unmqr` ("unitary mqr")
// — the same API surface as `ormqr` but with `cuComplex` /
// `cuDoubleComplex` storage. `trans = CUBLAS_OP_C` selects `Q^H`
// (conjugate transpose).
/// `cusolverDnCunmqr_bufferSize`.
///
/// # Safety
/// All pointers device-resident; `lwork` writable `int`.
pub fn cusolverDnCunmqr_bufferSize(
handle: cusolverDnHandle_t,
side: i32,
trans: i32,
m: i32,
n: i32,
k: i32,
a: *const cuFloatComplex,
lda: i32,
tau: *const cuFloatComplex,
c: *const cuFloatComplex,
ldc: i32,
lwork: *mut i32,
) -> i32;
/// `cusolverDnCunmqr` — apply `Q`, `Q^T`, or `Q^H` from a complex
/// `geqrf` factorization to a complex `C` in place.
///
/// # Safety
/// All pointers reference device memory; `workspace ≥
/// lwork * sizeof(cuFloatComplex)`; `dev_info` is one `int`.
pub fn cusolverDnCunmqr(
handle: cusolverDnHandle_t,
side: i32,
trans: i32,
m: i32,
n: i32,
k: i32,
a: *const cuFloatComplex,
lda: i32,
tau: *const cuFloatComplex,
c: *mut cuFloatComplex,
ldc: i32,
workspace: *mut cuFloatComplex,
lwork: i32,
dev_info: *mut i32,
) -> i32;
/// `cusolverDnZunmqr_bufferSize`. f64-complex analogue.
///
/// # Safety
/// Same as the C variant with `cuDoubleComplex` storage.
pub fn cusolverDnZunmqr_bufferSize(
handle: cusolverDnHandle_t,
side: i32,
trans: i32,
m: i32,
n: i32,
k: i32,
a: *const cuDoubleComplex,
lda: i32,
tau: *const cuDoubleComplex,
c: *const cuDoubleComplex,
ldc: i32,
lwork: *mut i32,
) -> i32;
/// `cusolverDnZunmqr`. f64-complex analogue.
///
/// # Safety
/// Same as the C variant with `cuDoubleComplex` storage.
pub fn cusolverDnZunmqr(
handle: cusolverDnHandle_t,
side: i32,
trans: i32,
m: i32,
n: i32,
k: i32,
a: *const cuDoubleComplex,
lda: i32,
tau: *const cuDoubleComplex,
c: *mut cuDoubleComplex,
ldc: i32,
workspace: *mut cuDoubleComplex,
lwork: i32,
dev_info: *mut i32,
) -> i32;
// ----- SVD: gesvd (f32 / f64) ----------------------------------------
/// `cusolverDnSgesvd_bufferSize`.
///
/// # Safety
/// `handle` live; `lwork` writable `int`. cuSOLVER's `gesvd_bufferSize`
/// signature does not take a matrix pointer (m and n suffice).
pub fn cusolverDnSgesvd_bufferSize(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
lwork: *mut i32,
) -> i32;
/// `cusolverDnSgesvd` — SVD: `A = U · diag(S) · V^T`. `jobu` selects how
/// `U` is returned and `jobv` how `V^T` is returned. Both are ASCII
/// bytes: `'A'` (full U/V^T), `'S'` (thin U/V^T), `'O'` (overwrite A —
/// disallowed at plan layer), `'N'` (skip).
///
/// NOTE(review): the cuSOLVER header spells these two parameters
/// `signed char jobu, signed char jobvt`. All four job codes are 7-bit
/// ASCII, so a `u8` carries the same byte — confirm `u8` is the intended
/// FFI type.
///
/// # Safety
/// All pointers reference device memory. Minimum sizes, in `float`
/// elements:
/// - `s`: `min(m, n)`
/// - `u`: `m * m` (full) or `m * min(m, n)` (thin)
/// - `vt`: `n * n` (full) or `min(m, n) * n` (thin)
/// - `workspace`: `lwork`
///
/// `rwork` may be null for real dtypes; `dev_info` is one `int`.
/// Important: cuSOLVER's `gesvd` **requires** `m ≥ n` — callers that
/// need `m < n` must transpose the input first.
pub fn cusolverDnSgesvd(
handle: cusolverDnHandle_t,
jobu: u8,
jobv: u8,
m: i32,
n: i32,
a: *mut f32,
lda: i32,
s: *mut f32,
u: *mut f32,
ldu: i32,
vt: *mut f32,
ldvt: i32,
workspace: *mut f32,
lwork: i32,
rwork: *mut f32,
dev_info: *mut i32,
) -> i32;
/// `cusolverDnDgesvd_bufferSize`. f64 analogue.
///
/// # Safety
/// Same as the f32 variant.
pub fn cusolverDnDgesvd_bufferSize(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
lwork: *mut i32,
) -> i32;
/// `cusolverDnDgesvd`. f64 analogue.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDgesvd(
handle: cusolverDnHandle_t,
jobu: u8,
jobv: u8,
m: i32,
n: i32,
a: *mut f64,
lda: i32,
s: *mut f64,
u: *mut f64,
ldu: i32,
vt: *mut f64,
ldvt: i32,
workspace: *mut f64,
lwork: i32,
rwork: *mut f64,
dev_info: *mut i32,
) -> i32;
// ----- Symmetric / Hermitian eigendecomposition: syevd / heevd ------
//
// `syevd` / `heevd` compute the eigenvalues + eigenvectors of a real
// symmetric (`syevd`) or complex Hermitian (`heevd`) matrix using the
// divide-and-conquer algorithm. The input matrix is overwritten in
// place with the eigenvectors (column-major); a separate `W` vector
// receives the (always-real) eigenvalues.
//
// `jobz` is `CUSOLVER_EIG_MODE_VECTOR` (compute eigenvectors) or
// `CUSOLVER_EIG_MODE_NOVECTOR` (eigenvalues only). `uplo` selects
// which triangle of the input to read (`CUBLAS_FILL_MODE_LOWER` /
// `_UPPER`).
/// `cusolverDnSsyevd_bufferSize` — query workspace element count for
/// real-symmetric divide-and-conquer eigh, f32.
///
/// # Safety
/// `handle` live; `A` device `n × n` `float` (lda ≥ n); `W` device
/// `float` of length `n`; `lwork` writable `int`.
pub fn cusolverDnSsyevd_bufferSize(
handle: cusolverDnHandle_t,
jobz: cusolverEigMode_t,
uplo: cublasFillMode_t,
n: i32,
a: *const f32,
lda: i32,
w: *const f32,
lwork: *mut i32,
) -> i32;
/// `cusolverDnSsyevd` — real-symmetric eigh, f32. `A` is overwritten
/// in place with the eigenvectors (column-major) when `jobz ==
/// VECTOR`. `W` receives the `n` eigenvalues sorted ascending.
///
/// # Safety
/// All pointers reference device memory; `workspace ≥ lwork *
/// sizeof(float)`; `dev_info` is one writable `int`.
pub fn cusolverDnSsyevd(
handle: cusolverDnHandle_t,
jobz: cusolverEigMode_t,
uplo: cublasFillMode_t,
n: i32,
a: *mut f32,
lda: i32,
w: *mut f32,
workspace: *mut f32,
lwork: i32,
dev_info: *mut i32,
) -> i32;
/// `cusolverDnDsyevd_bufferSize`. f64 analogue.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDsyevd_bufferSize(
handle: cusolverDnHandle_t,
jobz: cusolverEigMode_t,
uplo: cublasFillMode_t,
n: i32,
a: *const f64,
lda: i32,
w: *const f64,
lwork: *mut i32,
) -> i32;
/// `cusolverDnDsyevd`. f64 analogue.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDsyevd(
handle: cusolverDnHandle_t,
jobz: cusolverEigMode_t,
uplo: cublasFillMode_t,
n: i32,
a: *mut f64,
lda: i32,
w: *mut f64,
workspace: *mut f64,
lwork: i32,
dev_info: *mut i32,
) -> i32;
/// `cusolverDnCheevd_bufferSize` — complex-Hermitian divide-and-conquer
/// eigh, single precision (`Complex32`). Eigenvalues are real-valued
/// `float`.
///
/// # Safety
/// `handle` live; `A` device `n × n` `cuComplex` (lda ≥ n); `W`
/// device `float` of length `n`; `lwork` writable `int`.
pub fn cusolverDnCheevd_bufferSize(
handle: cusolverDnHandle_t,
jobz: cusolverEigMode_t,
uplo: cublasFillMode_t,
n: i32,
a: *const cuComplex,
lda: i32,
w: *const f32,
lwork: *mut i32,
) -> i32;
/// `cusolverDnCheevd` — complex-Hermitian eigendecomposition
/// (`Complex32`). `A` is overwritten in place with the eigenvectors
/// (column-major); `W` receives the `n` real eigenvalues sorted
/// ascending.
///
/// The matrix and the workspace are `cuComplex`; the eigenvalue output
/// `w` is plain `f32`. `lwork` is a length in `cuComplex` elements, per
/// the workspace requirement below.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// All pointers reference device memory; `workspace ≥ lwork *
/// sizeof(cuComplex)`; `dev_info` is one writable `int`.
pub fn cusolverDnCheevd(
handle: cusolverDnHandle_t,
jobz: cusolverEigMode_t,
uplo: cublasFillMode_t,
n: i32,
a: *mut cuComplex,
lda: i32,
w: *mut f32,
workspace: *mut cuComplex,
lwork: i32,
dev_info: *mut i32,
) -> i32;
/// `cusolverDnZheevd_bufferSize` — workspace-size query for
/// `cusolverDnZheevd`; the `Complex64` analogue of
/// `cusolverDnCheevd_bufferSize`.
///
/// The matrix is `cuDoubleComplex`, the eigenvalue buffer `w` is real
/// `f64`, and the required workspace size is written through `lwork`.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// Same as the `Cheevd` variant with `cuDoubleComplex` / `f64`
/// storage.
pub fn cusolverDnZheevd_bufferSize(
handle: cusolverDnHandle_t,
jobz: cusolverEigMode_t,
uplo: cublasFillMode_t,
n: i32,
a: *const cuDoubleComplex,
lda: i32,
w: *const f64,
lwork: *mut i32,
) -> i32;
/// `cusolverDnZheevd` — complex-Hermitian eigendecomposition, double
/// precision; the `Complex64` analogue of `cusolverDnCheevd`.
///
/// `a` and `workspace` are `cuDoubleComplex`; the eigenvalue output `w`
/// is real `f64`.
/// NOTE(review): by analogy with the `Cheevd` contract, `lwork` is
/// presumably a count of `cuDoubleComplex` elements — confirm.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// Same as the `Cheevd` variant with `cuDoubleComplex` / `f64`
/// storage.
pub fn cusolverDnZheevd(
handle: cusolverDnHandle_t,
jobz: cusolverEigMode_t,
uplo: cublasFillMode_t,
n: i32,
a: *mut cuDoubleComplex,
lda: i32,
w: *mut f64,
workspace: *mut cuDoubleComplex,
lwork: i32,
dev_info: *mut i32,
) -> i32;
// ----- Generic eigendecomposition: Xgeev (64-bit API) ---------------
//
// `Xgeev` is the cuSOLVER 11+ 64-bit-index API for the general (non-
// symmetric) eigendecomposition. Differences from the legacy `Sgeev`
// / `Dgeev` family:
//
// - Takes a `cusolverDnParams_t` opaque settings struct (created
// via `cusolverDnCreateParams`, destroyed via
// `cusolverDnDestroyParams`).
// - Indices are `int64_t`, not `int`.
// - Workspace sizes are `size_t` byte counts, NOT element counts;
// the buffer-size query returns BOTH a host-side and a device-side
// byte count, and `Xgeev` itself takes both buffers.
// - Tensor element types are passed as `cudaDataType` tags (CUDA_R_32F
// / CUDA_R_64F / CUDA_C_32F / CUDA_C_64F). The same routine handles
// all four input dtypes.
// - Eigenvalues `W` are **always complex** (`cudaDataType` must be
// CUDA_C_32F or CUDA_C_64F) — for real input the complex-conjugate
// pairs are stored explicitly rather than packed into a wr/wi
// LAPACK-style split.
//
// `jobvl` / `jobvr` are `CUSOLVER_EIG_MODE_VECTOR` (compute) or
// `CUSOLVER_EIG_MODE_NOVECTOR` (skip — pass null for the corresponding
// VL / VR pointers in that case).
/// `cusolverDnCreateParams` — allocate the opaque params struct used
/// by all 64-bit cuSOLVER APIs. Plan layer creates one lazily on
/// first `run` (mirroring the handle lifecycle).
///
/// The new `cusolverDnParams_t` is written through `params`; release
/// it with `cusolverDnDestroyParams`.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// `params` must point to writable storage for one `cusolverDnParams_t`.
pub fn cusolverDnCreateParams(params: *mut cusolverDnParams_t) -> i32;
/// `cusolverDnDestroyParams` — release a params struct previously
/// obtained from `cusolverDnCreateParams`. Returns 0 on success.
///
/// The struct is passed by value (not by pointer), so the caller's copy
/// of the handle is left dangling and must not be reused.
///
/// # Safety
/// `params` must be a live params struct returned by
/// `cusolverDnCreateParams` that has not already been destroyed.
pub fn cusolverDnDestroyParams(params: cusolverDnParams_t) -> i32;
/// `cusolverDnXgeev_bufferSize` — query the host + device byte
/// counts for `cusolverDnXgeev` at the given problem size and
/// element types. The two output pointers receive byte counts (NOT
/// element counts — different from the legacy `_bufferSize` APIs).
///
/// The argument list mirrors `cusolverDnXgeev` up to `compute_type`,
/// with the tensor pointers taken as `*const c_void`. Each tensor
/// (`a`, `w`, `vl`, `vr`) is paired with its own `cudaDataType` tag,
/// and all sizes / leading dimensions are `i64`.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// `handle` / `params` live; pointer args reference device memory of
/// the indicated `cudaDataType`; `workspace_in_bytes_on_device` and
/// `workspace_in_bytes_on_host` point to writable `size_t`.
pub fn cusolverDnXgeev_bufferSize(
handle: cusolverDnHandle_t,
params: cusolverDnParams_t,
jobvl: cusolverEigMode_t,
jobvr: cusolverEigMode_t,
n: i64,
data_type_a: cudaDataType,
a: *const c_void,
lda: i64,
data_type_w: cudaDataType,
w: *const c_void,
data_type_vl: cudaDataType,
vl: *const c_void,
ldvl: i64,
data_type_vr: cudaDataType,
vr: *const c_void,
ldvr: i64,
compute_type: cudaDataType,
workspace_in_bytes_on_device: *mut usize,
workspace_in_bytes_on_host: *mut usize,
) -> i32;
/// `cusolverDnXgeev` — general (non-symmetric) eigendecomposition.
/// `A` is **destroyed in place** (used as scratch by the LAPACK-
/// equivalent algorithm). `W` receives the `n` complex eigenvalues;
/// `VL` / `VR` (when requested) receive the column-major left /
/// right complex eigenvectors. For non-Hermitian input the
/// eigenvalues can be complex even when the input is real, hence
/// the always-complex `W` storage.
///
/// Two separate scratch buffers are supplied: one on the device and
/// one on the host, each with its own byte count. Both counts are
/// `usize` byte sizes, matching the `size_t` outputs of
/// `cusolverDnXgeev_bufferSize`.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// All tensor pointer args reference device memory of the indicated
/// `cudaDataType`; `workspace_on_device` ≥ `workspace_in_bytes_on_device`
/// device bytes; `workspace_on_host` ≥ `workspace_in_bytes_on_host`
/// host bytes (or null if `workspace_in_bytes_on_host == 0`); `info`
/// is one writable device `int`. Pass null for `VL` / `VR` when the
/// corresponding `jobv*` is `NOVECTOR`.
pub fn cusolverDnXgeev(
handle: cusolverDnHandle_t,
params: cusolverDnParams_t,
jobvl: cusolverEigMode_t,
jobvr: cusolverEigMode_t,
n: i64,
data_type_a: cudaDataType,
a: *mut c_void,
lda: i64,
data_type_w: cudaDataType,
w: *mut c_void,
data_type_vl: cudaDataType,
vl: *mut c_void,
ldvl: i64,
data_type_vr: cudaDataType,
vr: *mut c_void,
ldvr: i64,
compute_type: cudaDataType,
workspace_on_device: *mut c_void,
workspace_in_bytes_on_device: usize,
workspace_on_host: *mut c_void,
workspace_in_bytes_on_host: usize,
info: *mut i32,
) -> i32;
// ----- Batched QR: cublas*geqrfBatched (f32 / f64 / Complex32 / Complex64) -----
//
// NOTE: Despite belonging to the "linalg" family conceptually, the
// batched-QR factorization is implemented in **cuBLAS**, not cuSOLVER.
// cuSOLVER-Dn has no batched-geqrf entry point (only the non-batched
// `cusolverDn<t>geqrf`). cuBLAS's variant is workspace-free (cuBLAS
// allocates internally); it takes a *device-resident array of device
// pointers* (`Aarray[]`, `TauArray[]`) — the plan layer builds this
// array per-launch in caller-provided workspace.
/// `cublasSgeqrfBatched` — batched QR factorization (single precision).
/// Each `Aarray[b]` is overwritten in place with the `geqrf`-packed
/// `R` (upper) + Householder reflectors (strict lower);
/// `TauArray[b]` receives the Householder scalars.
///
/// Every slot shares the same `m`, `n` and `lda`; `batch_size` is the
/// number of entries in the two pointer arrays.
///
/// Returns the raw cuBLAS status code as `i32`.
///
/// # Safety
/// `Aarray` / `TauArray` are device arrays of device pointers (length
/// `batch_size`), and the matrices / tau vectors they point to are
/// device-resident as well. `info` is the one exception: it is a single
/// **host** `i32` indicating non-batched argument-validity (cuBLAS-batched
/// QR contract differs from cuSOLVER: it returns a single info, not
/// a per-slot array).
pub fn cublasSgeqrfBatched(
handle: cublasHandle_t,
m: i32,
n: i32,
a_array: *mut *mut f32,
lda: i32,
tau_array: *mut *mut f32,
info: *mut i32,
batch_size: i32,
) -> i32;
/// `cublasDgeqrfBatched` — batched QR factorization, double precision.
/// f64 analogue of `cublasSgeqrfBatched`: same argument order, with
/// `a_array` / `tau_array` holding `*mut f64` slot pointers.
///
/// Returns the raw cuBLAS status code as `i32`.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cublasDgeqrfBatched(
handle: cublasHandle_t,
m: i32,
n: i32,
a_array: *mut *mut f64,
lda: i32,
tau_array: *mut *mut f64,
info: *mut i32,
batch_size: i32,
) -> i32;
/// `cublasCgeqrfBatched` — batched QR factorization, `Complex32`
/// analogue of `cublasSgeqrfBatched`. `tau_array[b]` is
/// `cuComplex` (NOT real-typed even though tau is real-magnitude for
/// real Householder — cuBLAS uses complex tau across the complex
/// family so the same `apply` routines can dispatch uniformly).
///
/// Returns the raw cuBLAS status code as `i32`.
///
/// # Safety
/// Same as the f32 variant with `cuComplex` storage.
pub fn cublasCgeqrfBatched(
handle: cublasHandle_t,
m: i32,
n: i32,
a_array: *mut *mut cuComplex,
lda: i32,
tau_array: *mut *mut cuComplex,
info: *mut i32,
batch_size: i32,
) -> i32;
/// `cublasZgeqrfBatched` — batched QR factorization, `Complex64`
/// analogue of `cublasSgeqrfBatched`. Both the matrices and the tau
/// vectors are `cuDoubleComplex`.
///
/// Returns the raw cuBLAS status code as `i32`.
///
/// # Safety
/// Same as the f32 variant with `cuDoubleComplex` storage.
pub fn cublasZgeqrfBatched(
handle: cublasHandle_t,
m: i32,
n: i32,
a_array: *mut *mut cuDoubleComplex,
lda: i32,
tau_array: *mut *mut cuDoubleComplex,
info: *mut i32,
batch_size: i32,
) -> i32;
// ----- cuBLAS handle lifecycle ---------------------------------------
/// `cublasCreate_v2` — create a cuBLAS handle. The new handle is
/// written through `handle`; release it with `cublasDestroy_v2`.
///
/// Returns the raw cuBLAS status code as `i32`.
///
/// # Safety
/// `handle` must point to writable storage for one `cublasHandle_t`.
/// NOTE(review): presumably a CUDA context must be current on the
/// calling thread — confirm against the plan layer's setup order.
pub fn cublasCreate_v2(handle: *mut cublasHandle_t) -> i32;
/// `cublasDestroy_v2` — destroy a cuBLAS handle. The handle is passed
/// by value, so the caller's copy must not be used afterwards.
///
/// Returns the raw cuBLAS status code as `i32`.
///
/// # Safety
/// `handle` must be a live handle obtained from `cublasCreate_v2` that
/// has not already been destroyed.
pub fn cublasDestroy_v2(handle: cublasHandle_t) -> i32;
/// `cublasSetStream_v2` — bind a CUDA stream to the cuBLAS handle.
/// `stream` is the `cudaStream_t` cast to `*mut c_void`, the same
/// convention this crate uses for its other stream arguments.
///
/// Returns the raw cuBLAS status code as `i32`.
///
/// # Safety
/// `handle` must be a live cuBLAS handle. `stream` must be a valid CUDA
/// stream.
/// NOTE(review): a null `stream` presumably selects the default stream —
/// confirm before relying on it.
pub fn cublasSetStream_v2(handle: cublasHandle_t, stream: *mut c_void) -> i32;
// ----- Strided-batched GEMM: cublas{S,D,C,Z}gemmStridedBatched (f32 / f64 / Complex32 / Complex64) -----
//
// Single-launch batched GEMM where each batch slot has identical
// shape `(m, n, k)` but its operand pointers are reached by adding
// a fixed `stride{A,B,C}` (in *element* counts) to the base pointer.
// Used by the WY-blocked batched-`ormqr` plan (Milestone 6.17) to
// apply each block reflector via three GEMMs per block: V^T·C, T·W,
// and the rank-`nb` update C -= V·W.
//
// `alpha` / `beta` are host pointers (cuBLAS default pointer-mode).
/// `cublasSgemmStridedBatched` — single-precision strided-batched
/// matrix-matrix multiply. Each slot computes
/// `C[i] := α · op(A[i]) · op(B[i]) + β · C[i]` where `A[i]`,
/// `B[i]`, `C[i]` are reached by stepping `stride{A,B,C}` element
/// counts from the respective base pointers.
///
/// `transa` / `transb` select `op(·)` for `A` and `B` and are passed as
/// raw `i32` values. `m`, `n`, `k` are shared by all `batch_count`
/// slots, as are the leading dimensions `lda` / `ldb` / `ldc`.
/// NOTE(review): `transa` / `transb` presumably carry
/// `cublasOperation_t` values — confirm the constants used by callers.
///
/// Returns the raw cuBLAS status code as `i32`.
///
/// # Safety
/// `handle` is a live cuBLAS handle bound to the desired stream.
/// `alpha` / `beta` are host pointers to one `f32`. `a`, `b`, `c`
/// are device pointers. Strides are in `f32` element counts (not
/// bytes).
pub fn cublasSgemmStridedBatched(
handle: cublasHandle_t,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: *const f32,
a: *const f32,
lda: i32,
stride_a: i64,
b: *const f32,
ldb: i32,
stride_b: i64,
beta: *const f32,
c: *mut f32,
ldc: i32,
stride_c: i64,
batch_count: i32,
) -> i32;
/// `cublasDgemmStridedBatched` — double-precision strided-batched
/// matrix-matrix multiply. f64 analogue of [`cublasSgemmStridedBatched`]:
/// identical argument order, with `f64` scalars and operands.
///
/// Returns the raw cuBLAS status code as `i32`.
///
/// # Safety
/// Same as the f32 variant with `f64` storage; strides are in `f64`
/// element counts.
pub fn cublasDgemmStridedBatched(
handle: cublasHandle_t,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: *const f64,
a: *const f64,
lda: i32,
stride_a: i64,
b: *const f64,
ldb: i32,
stride_b: i64,
beta: *const f64,
c: *mut f64,
ldc: i32,
stride_c: i64,
batch_count: i32,
) -> i32;
/// `cublasCgemmStridedBatched` — single-precision complex strided-
/// batched matrix-matrix multiply. `Complex32` (== `cuComplex` ==
/// `cuFloatComplex`) analogue of [`cublasSgemmStridedBatched`]. Used
/// by the WY-blocked batched-`unmqr` plan ([`crate`]'s
/// `BatchedOrmqrWyPlan<Complex32>`) — `transa = CUBLAS_OP_C` selects
/// `V^H` for the first GEMM and `T^H` for the second GEMM when
/// applying `Q^H`.
///
/// The scalars `alpha` / `beta` are `cuComplex` as well, not real.
///
/// Returns the raw cuBLAS status code as `i32`.
///
/// # Safety
/// Same shape as [`cublasSgemmStridedBatched`] with `cuComplex`
/// storage; strides are in `cuComplex` element counts (8 bytes).
pub fn cublasCgemmStridedBatched(
handle: cublasHandle_t,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: *const cuComplex,
a: *const cuComplex,
lda: i32,
stride_a: i64,
b: *const cuComplex,
ldb: i32,
stride_b: i64,
beta: *const cuComplex,
c: *mut cuComplex,
ldc: i32,
stride_c: i64,
batch_count: i32,
) -> i32;
/// `cublasZgemmStridedBatched` — double-precision complex strided-
/// batched matrix-matrix multiply. `Complex64` analogue of
/// [`cublasCgemmStridedBatched`]: identical argument order, with
/// `cuDoubleComplex` scalars and operands.
///
/// Returns the raw cuBLAS status code as `i32`.
///
/// # Safety
/// Same shape with `cuDoubleComplex` storage; strides are in
/// `cuDoubleComplex` element counts (16 bytes).
pub fn cublasZgemmStridedBatched(
handle: cublasHandle_t,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: *const cuDoubleComplex,
a: *const cuDoubleComplex,
lda: i32,
stride_a: i64,
b: *const cuDoubleComplex,
ldb: i32,
stride_b: i64,
beta: *const cuDoubleComplex,
c: *mut cuDoubleComplex,
ldc: i32,
stride_c: i64,
batch_count: i32,
) -> i32;
// ----- cublasGemmEx (Phase 47) --------------------------------------
//
// Mixed-precision GEMM with explicit `cudaDataType` tags + a separate
// `cublasComputeType_t` accumulator. Needed by the Phase 47 FLCE
// path for f16/bf16 GEMM with fp32 accumulator + arbitrary
// transa/transb. `GemmPlan` doesn't expose col-major-A layout that
// the grad_weight accumulating GEMM requires, so we drop straight
// to cuBLAS here.
/// `cublasGemmEx` — mixed-precision GEMM with explicit dtype tags.
///
/// Operands are untyped (`c_void`) and each carries its own tag:
/// `a_type` / `b_type` / `c_type` describe the element type of `a` /
/// `b` / `c`, while `compute_type` selects the accumulator and thereby
/// the type `alpha` / `beta` must point to. All tags, plus `transa` /
/// `transb` / `algo`, cross the FFI boundary as raw `i32` values.
/// NOTE(review): `algo` is presumably a `cublasGemmAlgo_t` value —
/// confirm which constant the callers pass.
///
/// Returns the raw cuBLAS status code as `i32`.
///
/// # Safety
/// `handle` is a live cuBLAS handle bound to the desired stream.
/// `alpha`/`beta` are host pointers; their type must match
/// `compute_type`. `a`/`b`/`c` are device pointers; their element
/// type must match the respective `cudaDataType` tags.
pub fn cublasGemmEx(
handle: cublasHandle_t,
transa: i32, transb: i32,
m: i32, n: i32, k: i32,
alpha: *const c_void,
a: *const c_void, a_type: i32, lda: i32,
b: *const c_void, b_type: i32, ldb: i32,
beta: *const c_void,
c: *mut c_void, c_type: i32, ldc: i32,
compute_type: i32, algo: i32,
) -> i32;
/// `cublasGemmStridedBatchedEx` — mixed-precision strided-batched
/// GEMM with explicit dtype tags (Phase 74). The `Ex` sibling of
/// [`cublasSgemmStridedBatched`]: each batch slot `i` computes
/// `C[i] := α · op(A[i]) · op(B[i]) + β · C[i]` where the slot-`i`
/// operand is reached by adding `i * stride_*` (in **elements**) to
/// the base pointer. `stride_a` / `stride_b` may be `0` to broadcast
/// one matrix across all slots; `stride_c` must step disjoint
/// output regions.
///
/// Note the argument order: `batch_count` comes *before*
/// `compute_type` / `algo`, unlike the typed strided-batched entry
/// points where it is the final argument.
///
/// Returns the raw cuBLAS status code as `i32`.
///
/// # Safety
/// `handle` is a live cuBLAS handle bound to the desired stream.
/// `alpha`/`beta` are host pointers; their type must match
/// `compute_type`. `a`/`b`/`c` are device pointers; their element
/// type must match the respective `cudaDataType` tags.
pub fn cublasGemmStridedBatchedEx(
handle: cublasHandle_t,
transa: i32, transb: i32,
m: i32, n: i32, k: i32,
alpha: *const c_void,
a: *const c_void, a_type: i32, lda: i32, stride_a: i64,
b: *const c_void, b_type: i32, ldb: i32, stride_b: i64,
beta: *const c_void,
c: *mut c_void, c_type: i32, ldc: i32, stride_c: i64,
batch_count: i32,
compute_type: i32, algo: i32,
) -> i32;
// ----- Triangular solve: cublas{S,D}trsm (f32 / f64) -----------------
//
// `trsm` solves one of the matrix equations
// op(A) · X = α · B (side = LEFT)
// X · op(A) = α · B (side = RIGHT)
// for `X`, overwriting `B` in place. `A` is triangular (upper or
// lower per `uplo`); `op(A)` is `A`, `A^T`, or `A^H` per `trans`.
// `alpha` is a host pointer (cuBLAS default pointer-mode).
//
// The LstSq QR-fallback path uses
// side=LEFT, uplo=UPPER, trans=N, diag=NON_UNIT, α=1
// to back-substitute `R · X = Q^T · B` (with `R` the top-left
// `N × N` upper triangle of the post-`geqrf` packed `A`).
/// `cublasStrsm` — single-precision triangular solve.
///
/// # Safety
/// `handle` is a live cuBLAS handle bound to the desired stream.
/// `alpha` is a host pointer to one `f32`. `a` is a device pointer
/// to at least `lda · k` floats (`k = m` for LEFT, `n` for RIGHT)
/// of which only the requested triangle is read. `b` is device-
/// resident `[ldb · n]` and is overwritten with the solution `X`.
pub fn cublasStrsm(
handle: cublasHandle_t,
side: i32,
uplo: cublasFillMode_t,
trans: i32,
diag: cublasDiagType_t,
m: i32,
n: i32,
alpha: *const f32,
a: *const f32,
lda: i32,
b: *mut f32,
ldb: i32,
) -> i32;
/// `cublasDtrsm` — double-precision triangular solve. f64 analogue
/// of [`cublasStrsm`].
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cublasDtrsm(
handle: cublasHandle_t,
side: i32,
uplo: cublasFillMode_t,
trans: i32,
diag: cublasDiagType_t,
m: i32,
n: i32,
alpha: *const f64,
a: *const f64,
lda: i32,
b: *mut f64,
ldb: i32,
) -> i32;
// ----- Batched SVD (Jacobi): gesvdjBatched (f32 / f64) ---------------
//
// Jacobi-method batched SVD. Requires a `gesvdjInfo_t` parameter
// object created via `cusolverDnCreateGesvdjInfo` / destroyed via
// `cusolverDnDestroyGesvdjInfo`. The plan layer creates one on first
// run (same lifetime pattern as the cuSOLVER handle itself).
/// `cusolverDnCreateGesvdjInfo` — allocate a Jacobi-SVD params object
/// with cuSOLVER's defaults (`tol = 1e-7` for f32 / `1e-12` for f64,
/// `max_sweeps = 100`, `sort_eig = 1`).
///
/// The new `gesvdjInfo_t` is written through `info`; release it with
/// `cusolverDnDestroyGesvdjInfo`. Despite the parameter name, this is
/// the solver configuration object, not a status word.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// `info` must point to writable storage for one `gesvdjInfo_t`.
pub fn cusolverDnCreateGesvdjInfo(info: *mut gesvdjInfo_t) -> i32;
/// `cusolverDnDestroyGesvdjInfo` — release a Jacobi-SVD params object
/// created by `cusolverDnCreateGesvdjInfo`. Returns 0 on success.
///
/// The object is passed by value, so the caller's copy must not be
/// used afterwards.
///
/// # Safety
/// `info` must be a valid `gesvdjInfo_t` returned by
/// `cusolverDnCreateGesvdjInfo` that has not been previously destroyed.
pub fn cusolverDnDestroyGesvdjInfo(info: gesvdjInfo_t) -> i32;
/// `cusolverDnSgesvdjBatched_bufferSize` — workspace-size query for
/// `cusolverDnSgesvdjBatched`. `jobz` is `0` (no vectors)
/// or `1` (compute U / V). For batched, each matrix in `A` is
/// independently SVD'd; outputs are packed `[batch * m * m]` etc.
///
/// The required workspace size is written through `lwork`; the exec
/// call's contract (`workspace ≥ lwork * sizeof(T)`) makes it an
/// element count. Note that `lwork` precedes `params` and
/// `batch_size` in the argument list.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// `handle` live; `params` a valid `gesvdjInfo_t`; `lwork` writable
/// `int`.
pub fn cusolverDnSgesvdjBatched_bufferSize(
handle: cusolverDnHandle_t,
jobz: i32,
m: i32,
n: i32,
a: *const f32,
lda: i32,
s: *const f32,
u: *const f32,
ldu: i32,
v: *const f32,
ldv: i32,
lwork: *mut i32,
params: gesvdjInfo_t,
batch_size: i32,
) -> i32;
/// `cusolverDnSgesvdjBatched` — batched Jacobi SVD `A = U · diag(S) · V^T`
/// (single precision). Each matrix is square `[m, m]` (cuSOLVER's
/// Jacobi-batched API requires square input; thin rectangular is
/// achievable via `gesvdaStridedBatched` — deferred). The plan
/// surfaces `V` (not `V^T`); callers apply the transpose if needed.
///
/// NOTE(review): the cuSOLVER reference also states a per-matrix size
/// limit (32 × 32) for `gesvdjBatched` — confirm against the CUDA
/// version in use.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// All pointers are device-resident. With `T = f32` for this entry
/// point: `S ≥ batch*min(m,n) * sizeof(T)`;
/// `U ≥ batch*m*m * sizeof(T)`; `V ≥ batch*n*n * sizeof(T)`;
/// `workspace ≥ lwork * sizeof(T)`; `info` (the per-slot status
/// array) `≥ batch * sizeof(int)`. `params` must be a valid
/// `gesvdjInfo_t`.
pub fn cusolverDnSgesvdjBatched(
handle: cusolverDnHandle_t,
jobz: i32,
m: i32,
n: i32,
a: *mut f32,
lda: i32,
s: *mut f32,
u: *mut f32,
ldu: i32,
v: *mut f32,
ldv: i32,
workspace: *mut f32,
lwork: i32,
info: *mut i32,
params: gesvdjInfo_t,
batch_size: i32,
) -> i32;
/// `cusolverDnDgesvdjBatched_bufferSize` — workspace-size query for
/// `cusolverDnDgesvdjBatched`; f64 analogue of
/// `cusolverDnSgesvdjBatched_bufferSize` with the same argument order.
/// The required size is written through `lwork`.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDgesvdjBatched_bufferSize(
handle: cusolverDnHandle_t,
jobz: i32,
m: i32,
n: i32,
a: *const f64,
lda: i32,
s: *const f64,
u: *const f64,
ldu: i32,
v: *const f64,
ldv: i32,
lwork: *mut i32,
params: gesvdjInfo_t,
batch_size: i32,
) -> i32;
/// `cusolverDnDgesvdjBatched` — batched Jacobi SVD, double precision.
/// f64 analogue of `cusolverDnSgesvdjBatched` with the same argument
/// order; `a`, `s`, `u`, `v` and `workspace` are `f64`, while `info`
/// remains `i32`.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDgesvdjBatched(
handle: cusolverDnHandle_t,
jobz: i32,
m: i32,
n: i32,
a: *mut f64,
lda: i32,
s: *mut f64,
u: *mut f64,
ldu: i32,
v: *mut f64,
ldv: i32,
workspace: *mut f64,
lwork: i32,
info: *mut i32,
params: gesvdjInfo_t,
batch_size: i32,
) -> i32;
// ----- Least-squares: gels (f32 / f64) -------------------------------
//
// Mixed-precision iterative-refinement `_gels` routine. The single-
// precision entry is `cusolverDnSSgels` (S-input, S-compute) and the
// double-precision entry is `cusolverDnDDgels` (D-input, D-compute).
// Other letter combinations (SH, SB, DS, DH, DB) exist for mixed-
// precision strategies but the plan layer surfaces only the
// same-precision variants today. The routine returns the iteration
// count via `niters`; if it failed to converge the safe-plan layer
// reports a non-convergence error (QR-fallback is deferred).
/// `cusolverDnSSgels_bufferSize` — query bytes (the routine's
/// workspace is supplied as a raw byte buffer, not a typed element
/// count, distinct from the `*_bufferSize` entries above).
///
/// The byte count is written through `lwork_bytes`. The remaining
/// arguments mirror `cusolverDnSSgels`, including a `workspace`
/// pointer.
/// NOTE(review): `workspace` is presumably allowed to be null for the
/// size query — confirm against the cuSOLVER reference.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// `handle` live; `lwork_bytes` writable `size_t`.
pub fn cusolverDnSSgels_bufferSize(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
nrhs: i32,
a: *mut f32,
lda: i32,
b: *mut f32,
ldb: i32,
x: *mut f32,
ldx: i32,
workspace: *mut c_void,
lwork_bytes: *mut usize,
) -> i32;
/// `cusolverDnSSgels` — least-squares solve `min ||A·x - b||²` for
/// `m ≥ n` full-rank `A`. Iterative refinement; returns `niters` ≥ 0
/// on convergence, `-N` on fallback-needed. Single precision.
///
/// The iteration count is delivered through the `niters` out-pointer;
/// the function's own return value is the raw cuSOLVER status code.
/// `workspace` is an untyped byte buffer of `lwork_bytes` bytes.
///
/// # Safety
/// `a`, `b`, `x`, `workspace` and `dev_info` reference device memory;
/// `x` is the device-resident solution buffer of length
/// `n * nrhs * sizeof(T)`. `workspace_bytes` from the matching
/// `_bufferSize`.
/// NOTE(review): the cuSOLVER reference documents `niters` as a
/// **host**-side output, unlike the device-side `dev_info` — confirm
/// where callers allocate it.
pub fn cusolverDnSSgels(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
nrhs: i32,
a: *mut f32,
lda: i32,
b: *mut f32,
ldb: i32,
x: *mut f32,
ldx: i32,
workspace: *mut c_void,
lwork_bytes: usize,
niters: *mut i32,
dev_info: *mut i32,
) -> i32;
/// `cusolverDnDDgels_bufferSize` — workspace-size query (in bytes,
/// written through `lwork_bytes`) for `cusolverDnDDgels`. f64 analogue
/// of `cusolverDnSSgels_bufferSize` with the same argument order.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDDgels_bufferSize(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
nrhs: i32,
a: *mut f64,
lda: i32,
b: *mut f64,
ldb: i32,
x: *mut f64,
ldx: i32,
workspace: *mut c_void,
lwork_bytes: *mut usize,
) -> i32;
/// `cusolverDnDDgels` — least-squares solve, double precision. f64
/// analogue of `cusolverDnSSgels` with the same argument order: `a`,
/// `b` and `x` are `f64`, `workspace` stays an untyped byte buffer of
/// `lwork_bytes` bytes, and the iteration count is delivered through
/// `niters`.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDDgels(
handle: cusolverDnHandle_t,
m: i32,
n: i32,
nrhs: i32,
a: *mut f64,
lda: i32,
b: *mut f64,
ldb: i32,
x: *mut f64,
ldx: i32,
workspace: *mut c_void,
lwork_bytes: usize,
niters: *mut i32,
dev_info: *mut i32,
) -> i32;
// ----- Rectangular batched SVD: gesvdaStridedBatched (f32 / f64) ------
//
// Approximate-SVD batched API that, unlike `gesvdjBatched`, accepts
// **rectangular** `[m, n]` matrices and uses **element-strides** between
// batch slots (not pointer arrays). Per-slot residual Frobenius norms
// are written to a **host** array `h_R_nrmF`.
//
// Gotcha: the `lwork` returned by `_bufferSize` (and accepted by the
// exec call) is measured in **elements**, not bytes — multiply by
// `sizeof(T)` to get the byte count for the `Workspace` buffer.
//
// The `rank` parameter (≤ `min(m, n)`) selects the number of singular
// triplets to compute; pass `min(m, n)` for the full thin SVD. Outputs
// are `S: [batch, rank]`, `U: [batch, m, rank]`, `V: [batch, n, rank]`
// (column-major per slot). cuSOLVER returns `V` directly (not `V^T`).
/// `cusolverDnSgesvdaStridedBatched_bufferSize` — query the device
/// workspace size (in **elements**, multiply by `sizeof(f32)` for
/// bytes) for the f32 rectangular-batched approximate-SVD.
///
/// The size is written through `lwork`. Every tensor is described by
/// a base pointer plus an `i64` per-slot stride; `a`, `u` and `v` also
/// carry a leading dimension, while `s` has only a stride.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// `handle` live; `lwork` writable `int`. Pointer args may be null
/// (they're only inspected for shape inference).
pub fn cusolverDnSgesvdaStridedBatched_bufferSize(
handle: cusolverDnHandle_t,
jobz: i32,
rank: i32,
m: i32,
n: i32,
a: *const f32,
lda: i32,
stride_a: i64,
s: *const f32,
stride_s: i64,
u: *const f32,
ldu: i32,
stride_u: i64,
v: *const f32,
ldv: i32,
stride_v: i64,
lwork: *mut i32,
batch_size: i32,
) -> i32;
/// `cusolverDnSgesvdaStridedBatched` — f32 rectangular-batched
/// approximate-SVD. Each batch slot factors a `[m, n]` matrix into
/// `U: [m, rank]`, `S: [rank]`, `V: [n, rank]` (column-major;
/// cuSOLVER returns `V`, not `V^T`). The host array `h_R_nrmF` (size
/// `batch_size`) receives per-slot residual Frobenius norms.
///
/// The input `a` is taken as `*const`, unlike the in-place `gesvdj`
/// entry points. `h_r_nrm_f` is `*mut f64` even though the matrices
/// are `f32`. `lwork` is an element count, not a byte count.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// Device pointers `a`, `s`, `u` (when `jobz == VECTOR`), `v` (when
/// `jobz == VECTOR`), `work`, `info` must be valid; `h_R_nrmF` is a
/// **host** buffer of `batch_size` `f64`s; `lwork` from the matching
/// `_bufferSize`.
pub fn cusolverDnSgesvdaStridedBatched(
handle: cusolverDnHandle_t,
jobz: i32,
rank: i32,
m: i32,
n: i32,
a: *const f32,
lda: i32,
stride_a: i64,
s: *mut f32,
stride_s: i64,
u: *mut f32,
ldu: i32,
stride_u: i64,
v: *mut f32,
ldv: i32,
stride_v: i64,
work: *mut f32,
lwork: i32,
info: *mut i32,
h_r_nrm_f: *mut f64,
batch_size: i32,
) -> i32;
/// `cusolverDnDgesvdaStridedBatched_bufferSize` — workspace-size query
/// for `cusolverDnDgesvdaStridedBatched`; f64 analogue of the f32
/// query with the same argument order. The size is written through
/// `lwork`.
/// NOTE(review): by analogy with the f32 query, the count is presumably
/// in `f64` elements rather than bytes — confirm.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDgesvdaStridedBatched_bufferSize(
handle: cusolverDnHandle_t,
jobz: i32,
rank: i32,
m: i32,
n: i32,
a: *const f64,
lda: i32,
stride_a: i64,
s: *const f64,
stride_s: i64,
u: *const f64,
ldu: i32,
stride_u: i64,
v: *const f64,
ldv: i32,
stride_v: i64,
lwork: *mut i32,
batch_size: i32,
) -> i32;
/// `cusolverDnDgesvdaStridedBatched` — f64 rectangular-batched
/// approximate-SVD. f64 analogue of `cusolverDnSgesvdaStridedBatched`
/// with the same argument order; `a`, `s`, `u`, `v` and `work` are
/// `f64`, and the residual-norm array `h_r_nrm_f` is `f64` as in the
/// f32 variant.
///
/// Returns the raw cuSOLVER status code as `i32`.
///
/// # Safety
/// Same as the f32 variant with `f64` storage.
pub fn cusolverDnDgesvdaStridedBatched(
handle: cusolverDnHandle_t,
jobz: i32,
rank: i32,
m: i32,
n: i32,
a: *const f64,
lda: i32,
stride_a: i64,
s: *mut f64,
stride_s: i64,
u: *mut f64,
ldu: i32,
stride_u: i64,
v: *mut f64,
ldv: i32,
stride_v: i64,
work: *mut f64,
lwork: i32,
info: *mut i32,
h_r_nrm_f: *mut f64,
batch_size: i32,
) -> i32;
}
/// `CUSOLVER_EIG_MODE_NOVECTOR` — `gesvdjBatched` `jobz` value for
/// computing singular values only (skip U / V).
///
/// Typed `i32` to match the `jobz: i32` parameters of the batched-SVD
/// bindings.
/// NOTE(review): presumably the same value applies to the
/// `gesvdaStridedBatched` `jobz` argument — confirm.
pub const CUSOLVER_EIG_MODE_NOVECTOR: i32 = 0;
/// `CUSOLVER_EIG_MODE_VECTOR` — `gesvdjBatched` `jobz` value for
/// computing both singular values and singular vectors.
///
/// Typed `i32` to match the `jobz: i32` parameters of the batched-SVD
/// bindings.
pub const CUSOLVER_EIG_MODE_VECTOR: i32 = 1;
// ============================================================================
// cuFFT — Milestone 6.4 Fast Fourier Transforms
// ============================================================================
//
// Host-API cuFFT bindings for the four canonical PyTorch / JAX 1-D FFT
// ops: FFT (C2C forward), IFFT (C2C inverse), RFFT (R2C forward), IRFFT
// (C2R inverse). Plus single/double precision sibling entry points.
//
// f32 (single) + f64 (double) only — cuFFT's main API does not expose
// f16 / bf16 for native FFTs. Callers needing reduced precision must
// cast on either side. Inverse transforms are *unnormalized* — a cuFFT
// forward + inverse round trip yields N · x; the safe-plan layer
// multiplies by 1/N after each inverse exec to match PyTorch's
// `norm="backward"` default.
//
// Linkage: `cargo:rustc-link-lib=dylib=cufft` (added in build.rs). On
// Linux this resolves to `libcufft.so`; on Windows to `cufft64_*.dll`
// (loaded from `CUDA_PATH\bin`).
//
// Note: cuFFT handles are **integer IDs**, not pointers. This is unusual
// among CUDA libraries (cuSOLVER / cuBLAS / cuRAND all use opaque
// pointer handles) — we represent the handle as `i32` to match the
// upstream C ABI exactly. A sentinel value of `-1` marks "not yet
// created" at the plan layer.
/// Opaque cuFFT plan handle. Unusually for CUDA libraries this is an
/// **integer ID** (`int`), not a pointer. A value of `-1` is reserved
/// at the safe-plan layer as the "not yet created" sentinel — cuFFT
/// itself returns small non-negative integers for live handles.
///
/// Being a plain `i32` alias, the handle is `Copy`; nothing at the
/// type level prevents use after `cufftDestroy`.
#[allow(non_camel_case_types)]
pub type cufftHandle = i32;
/// cuFFT result code type. `CUFFT_SUCCESS = 0`. Any non-zero return is
/// mapped to a negative status at the safe-plan layer for distinct
/// error reporting.
///
/// This is a transparent alias for `i32`, so it is interchangeable with
/// the `i32` return type used by the cuFFT bindings in this file.
#[allow(non_camel_case_types)]
pub type cufftResult = i32;
/// `CUFFT_SUCCESS` — the only success code. Compare the `i32` returned
/// by any cuFFT binding against this value.
pub const CUFFT_SUCCESS: i32 = 0;
/// cuFFT plan type: real-to-complex (single precision). Output buffer
/// size is `N/2 + 1` complex cells for an `N`-long real input
/// (Hermitian symmetry). Passed as the `fft_type` argument of the
/// plan-creation entry points.
pub const CUFFT_R2C: i32 = 0x2a;
/// cuFFT plan type: complex-to-real (single precision). Input is
/// `N/2 + 1` complex cells (Hermitian-half), output is `N` real cells.
pub const CUFFT_C2R: i32 = 0x2c;
/// cuFFT plan type: complex-to-complex (single precision). Direction is
/// supplied to `cufftExecC2C`.
pub const CUFFT_C2C: i32 = 0x29;
/// cuFFT plan type: double-precision real-to-complex. The double-
/// precision tags equal their single-precision siblings plus `0x40`.
pub const CUFFT_D2Z: i32 = 0x6a;
/// cuFFT plan type: double-precision complex-to-real.
pub const CUFFT_Z2D: i32 = 0x6c;
/// cuFFT plan type: double-precision complex-to-complex. Direction is
/// supplied to `cufftExecZ2Z`.
pub const CUFFT_Z2Z: i32 = 0x69;
/// Forward FFT direction tag for `cufftExecC2C` / `cufftExecZ2Z`.
/// cuFFT's forward transform is unnormalized.
pub const CUFFT_FORWARD: i32 = -1;
/// Inverse FFT direction tag for `cufftExecC2C` / `cufftExecZ2Z`.
/// cuFFT's inverse transform is **also unnormalized** — the safe-plan
/// layer multiplies the output by `1/N` after exec to match PyTorch's
/// `norm="backward"` (forward unnormalized, inverse normalized by N)
/// convention.
pub const CUFFT_INVERSE: i32 = 1;
/// One single-precision complex sample, stored as an interleaved
/// (real, imaginary) pair of `f32`. The `#[repr(C)]` field order mirrors
/// NVIDIA's `cufftComplex` (an alias for `float2` in
/// `<vector_types.h>`). The plan layer pairs this with the
/// [`crate`]-level `Complex32` newtype.
///
/// NOTE(review): CUDA declares `float2` with 8-byte alignment, whereas
/// this struct is 4-byte aligned. Size and field offsets agree —
/// confirm that no caller depends on the stricter alignment.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
#[repr(C)]
pub struct cufftComplex {
    /// Real part of the sample.
    pub x: f32,
    /// Imaginary part of the sample.
    pub y: f32,
}
/// One double-precision complex sample, stored as an interleaved
/// (real, imaginary) pair of `f64`. Field order mirrors cuFFT's
/// `cufftDoubleComplex` (alias for `double2`).
///
/// NOTE(review): CUDA declares `double2` with 16-byte alignment, whereas
/// this struct is 8-byte aligned. Size and field offsets agree —
/// confirm that no caller depends on the stricter alignment.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
#[repr(C)]
pub struct cufftDoubleComplex {
    /// Real part of the sample.
    pub x: f64,
    /// Imaginary part of the sample.
    pub y: f64,
}
// Raw cuFFT host-API entry points. Every function returns the cuFFT
// result code as a plain `i32`; compare against `CUFFT_SUCCESS`.
unsafe extern "C" {
// ----- plan lifecycle ------------------------------------------------
/// `cufftPlan1d(plan, nx, type, batch)`. Allocates a 1-D plan
/// (single FFT of length `nx`, or `batch` independent FFTs each of
/// length `nx` laid out contiguously). cuFFT's plan struct owns its
/// own workspace internally — no caller-supplied workspace is
/// required for the basic 1-D APIs.
///
/// `fft_type` is one of the `CUFFT_*` plan-type tags (`CUFFT_C2C`,
/// `CUFFT_R2C`, ...). The new handle is written through `plan`.
///
/// # Safety
/// `plan` must point to writable storage for one `cufftHandle`. The
/// underlying CUDA context must be live.
pub fn cufftPlan1d(plan: *mut cufftHandle, nx: i32, fft_type: i32, batch: i32) -> i32;
/// `cufftDestroy(plan)`. Frees the plan's internal workspace. The
/// handle is an integer passed by value, so the caller's copy must
/// not be reused afterwards.
///
/// # Safety
/// `plan` must be a valid handle returned by `cufftPlan1d` (or any
/// other plan-creation entry) that has not been destroyed.
pub fn cufftDestroy(plan: cufftHandle) -> i32;
/// `cufftSetStream(plan, stream)`. Binds subsequent exec calls on
/// this plan to the given CUDA stream. Returns 0 on success.
///
/// `stream` is the `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// `plan` must be a live cuFFT handle; `stream` must be a valid
/// CUDA stream in the current context (or null for the default).
pub fn cufftSetStream(plan: cufftHandle, stream: *mut c_void) -> i32;
// ----- exec entry points (single precision) --------------------------
/// `cufftExecC2C(plan, idata, odata, direction)` — complex-to-
/// complex single-precision exec. `direction` is `CUFFT_FORWARD`
/// or `CUFFT_INVERSE`. Inverse is unnormalized.
///
/// NOTE(review): `plan` is presumably required to have been created
/// with the `CUFFT_C2C` type tag — confirm at the plan layer.
///
/// # Safety
/// `plan` live, `idata` / `odata` device pointers to at least
/// `nx * batch` `cufftComplex` cells each (in-place exec when
/// `idata == odata` is allowed by cuFFT).
pub fn cufftExecC2C(
plan: cufftHandle,
idata: *mut cufftComplex,
odata: *mut cufftComplex,
direction: i32,
) -> i32;
/// `cufftExecR2C(plan, idata, odata)` — real-to-complex single
/// precision. Input length is `nx`, output length is `nx/2 + 1`
/// (Hermitian-half). There is no direction argument: R2C is the
/// forward transform of the real/complex pair.
///
/// # Safety
/// `plan` live, `idata` to `nx * batch` `float` cells, `odata` to
/// `(nx/2 + 1) * batch` `cufftComplex` cells.
pub fn cufftExecR2C(plan: cufftHandle, idata: *mut f32, odata: *mut cufftComplex) -> i32;
/// `cufftExecC2R(plan, idata, odata)` — complex-to-real single
/// precision. Input length is `nx/2 + 1`, output length is `nx`.
/// Unnormalized — caller must scale by `1/nx`.
///
/// NOTE(review): `idata` is `*mut`; the cuFFT reference says C2R
/// transforms may overwrite their input — confirm before reusing
/// the input buffer after the call.
///
/// # Safety
/// `plan` live, `idata` to `(nx/2 + 1) * batch` `cufftComplex`,
/// `odata` to `nx * batch` `float`.
pub fn cufftExecC2R(plan: cufftHandle, idata: *mut cufftComplex, odata: *mut f32) -> i32;
// ----- exec entry points (double precision) --------------------------
/// `cufftExecZ2Z(plan, idata, odata, direction)` — complex-to-
/// complex double precision. Same semantics as `cufftExecC2C`,
/// including the `CUFFT_FORWARD` / `CUFFT_INVERSE` direction tag.
///
/// # Safety
/// Same as `cufftExecC2C` with `cufftDoubleComplex` cells.
pub fn cufftExecZ2Z(
plan: cufftHandle,
idata: *mut cufftDoubleComplex,
odata: *mut cufftDoubleComplex,
direction: i32,
) -> i32;
/// `cufftExecD2Z(plan, idata, odata)` — real-to-complex double
/// precision. Same semantics as `cufftExecR2C`: `nx` real inputs
/// produce `nx/2 + 1` complex outputs per transform.
///
/// # Safety
/// `plan` live, `idata` to `nx * batch` `double` cells, `odata` to
/// `(nx/2 + 1) * batch` `cufftDoubleComplex` cells.
pub fn cufftExecD2Z(plan: cufftHandle, idata: *mut f64, odata: *mut cufftDoubleComplex)
-> i32;
/// `cufftExecZ2D(plan, idata, odata)` — complex-to-real double
/// precision. Unnormalized. Double-precision counterpart of
/// `cufftExecC2R`: `nx/2 + 1` complex inputs produce `nx` real
/// outputs per transform.
///
/// # Safety
/// `plan` live, `idata` to `(nx/2 + 1) * batch` `cufftDoubleComplex`,
/// `odata` to `nx * batch` `double`.
pub fn cufftExecZ2D(plan: cufftHandle, idata: *mut cufftDoubleComplex, odata: *mut f64)
-> i32;
// ----- plan lifecycle (multi-dimensional / advanced layout) ----------
/// `cufftPlanMany(plan, rank, n, inembed, istride, idist,
/// onembed, ostride, odist, type, batch)`.
///
/// Allocates a `rank`-D plan covering `batch` independent transforms.
/// `n` points to a `rank`-element array of per-axis lengths
/// (`n[0]` is the slowest-varying transform axis, `n[rank-1]` the
/// fastest). `inembed` / `onembed` describe the stride layout in
/// memory; passing `core::ptr::null_mut()` for both selects cuFFT's
/// "tight default layout" — each batched transform occupies a
/// contiguous block of `n[0] * n[1] * ... * n[rank-1]` elements
/// (the case the ND wrappers in `baracuda-kernels` use).
///
/// `istride` / `ostride` are element strides between consecutive
/// elements within a single transform (use `1` for the default
/// layout). `idist` / `odist` are batch strides — the element
/// offset from one transform's first element to the next. For the
/// default layout pass `idist = odist = product(n)` (R2C / C2R
/// follow cuFFT's Hermitian-half rules — the last-axis extent
/// halves on the complex side).
///
/// All extents, strides and distances are `i32`, so each must fit in
/// a 32-bit signed integer.
///
/// Returns 0 (`CUFFT_SUCCESS`) on success.
///
/// # Safety
/// `plan` must point to writable storage for one `cufftHandle`.
/// `n` must point to `rank` `i32` cells; `inembed` / `onembed` may
/// be null (default layout) or point to `rank` `i32` cells each.
/// The CUDA context must be live.
pub fn cufftPlanMany(
plan: *mut cufftHandle,
rank: i32,
n: *mut i32,
inembed: *mut i32,
istride: i32,
idist: i32,
onembed: *mut i32,
ostride: i32,
odist: i32,
fft_type: i32,
batch: i32,
) -> i32;
}
// ============================================================================
// cuFFT bespoke kernels — fftshift / ifftshift + in-place scale-by-1/N.
// ============================================================================
//
// Two bespoke kernel families used by the cuFFT wrap:
//
// 1. **fftshift / ifftshift** — index permutation along the last axis
// of a `[batch, n]` tensor. Templated on element width (4 bytes for
// f32, 8 bytes for f64 / Complex32, 16 bytes for Complex64). cuFFT
// has no native fftshift — these complete the `torch.fft` family.
//
// 2. **scale_inplace_{c32,c64,real_f32,real_f64}** — multiply an in-place buffer
// by a scalar. Used to apply the `1/N` normalization to inverse
// transforms (cuFFT returns N · IFFT(x); PyTorch's `norm="backward"`
// convention wants IFFT(x)).
//
// ABI mirrors the elementwise / random kernel families:
// 0 success, 1 misaligned, 2 invalid problem, 3 unsupported,
// 4 workspace too small, 5 internal launch failure.
unsafe extern "C" {
/// `fftshift` along the last axis of a `[batch, n]` tensor:
/// `y[b, i] = x[b, (i + n/2) % n]`. Element-width specialization
/// (4 bytes per element) — used for `Bool` / `f32` / packed-Bool
/// shifts; the same kernel re-instantiated at 8 / 16 bytes covers
/// `f64` / `Complex32` and `Complex64`.
///
/// NOTE(review): NumPy / PyTorch define `fftshift` as `roll(x, n/2)`,
/// which as a gather is `y[i] = x[(i + (n + 1)/2) % n]`. For odd `n`
/// the formula above therefore matches their `ifftshift`. Confirm
/// against the kernel source whether the documented index expression
/// is a gather or a scatter.
///
/// ABI: `(batch, n, x, y, ws, ws_bytes, stream) -> i32`.
///
/// - `batch` / `n`: row count and length of the shifted (last) axis.
/// - `x` / `y`: input and output buffers.
/// - `workspace` / `workspace_bytes`: scratch buffer and its size in
///   bytes; when non-null it must cover `workspace_bytes` of writable
///   device memory (crate-level contract).
/// - `stream`: the `cudaStream_t`, passed as `*mut c_void`.
///
/// Returns the status code shared by this kernel family: `0` success,
/// `1` misaligned, `2` invalid problem, `3` unsupported, `4` workspace
/// too small, `5` internal launch failure.
///
/// # Safety
/// `x` / `y` must each point to at least `batch * n` cells of the
/// kernel's element width. `stream` must be a live CUDA stream in
/// the current context.
pub fn baracuda_kernels_fftshift_4_run(
batch: i64,
n: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_fftshift_4_run`: takes
/// the same `(batch, n, x, y)` problem description without the
/// workspace and stream arguments, and returns the same status-code
/// family (`0` means the kernel accepts the problem).
///
/// # Safety
/// NOTE(review): presumably a host-side check that does not
/// dereference `x` / `y` — confirm against the kernel source. Until
/// then, pass pointers that satisfy the matching `_run` contract.
pub fn baracuda_kernels_fftshift_4_can_implement(
batch: i64,
n: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
/// 8-byte-element `fftshift` (covers `f64` and `Complex32`).
/// Parameters and status codes are those of
/// `baracuda_kernels_fftshift_4_run`.
///
/// # Safety
/// Same as `baracuda_kernels_fftshift_4_run` with 8-byte cells.
pub fn baracuda_kernels_fftshift_8_run(
batch: i64,
n: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_fftshift_8_run`; same
/// arguments minus workspace / stream, same status-code family.
///
/// # Safety
/// NOTE(review): presumably host-side only — confirm. Until then, pass
/// pointers that satisfy the matching `_run` contract.
pub fn baracuda_kernels_fftshift_8_can_implement(
batch: i64,
n: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
/// 16-byte-element `fftshift` (covers `Complex64`). Parameters and
/// status codes are those of `baracuda_kernels_fftshift_4_run`.
///
/// # Safety
/// Same as `baracuda_kernels_fftshift_4_run` with 16-byte cells.
pub fn baracuda_kernels_fftshift_16_run(
batch: i64,
n: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_fftshift_16_run`; same
/// arguments minus workspace / stream, same status-code family.
///
/// # Safety
/// NOTE(review): presumably host-side only — confirm. Until then, pass
/// pointers that satisfy the matching `_run` contract.
pub fn baracuda_kernels_fftshift_16_can_implement(
batch: i64,
n: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Inverse `fftshift` along the last axis of a `[batch, n]` tensor:
/// `y[b, i] = x[b, (i + (n + 1) / 2) % n]`. Differs from `fftshift`
/// only for odd `n`; for even `n` the two are identical (each
/// permutation is self-inverse). 4-byte cells.
///
/// NOTE(review): NumPy / PyTorch define `ifftshift` as
/// `roll(x, -(n/2))`, which as a gather is `y[i] = x[(i + n/2) % n]`.
/// For odd `n` the formula above therefore matches their `fftshift`.
/// Confirm against the kernel source whether the documented index
/// expression is a gather or a scatter.
///
/// Arguments are `(batch, n, x, y, workspace, workspace_bytes, stream)`
/// exactly as for `baracuda_kernels_fftshift_4_run`; the return value
/// is the kernel-family status code (`0` success).
///
/// # Safety
/// Same as `baracuda_kernels_fftshift_4_run`.
pub fn baracuda_kernels_ifftshift_4_run(
batch: i64,
n: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_ifftshift_4_run`: same
/// `(batch, n, x, y)` problem description without the workspace and
/// stream arguments; returns the same status-code family (`0` means
/// the kernel accepts the problem).
///
/// # Safety
/// NOTE(review): presumably a host-side check that does not
/// dereference `x` / `y` — confirm against the kernel source. Until
/// then, pass pointers that satisfy the matching `_run` contract.
pub fn baracuda_kernels_ifftshift_4_can_implement(
batch: i64,
n: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
/// 8-byte-element inverse `fftshift`. Parameters and status codes are
/// those of `baracuda_kernels_ifftshift_4_run`.
///
/// # Safety
/// Same as `baracuda_kernels_ifftshift_4_run` with 8-byte cells.
pub fn baracuda_kernels_ifftshift_8_run(
batch: i64,
n: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_ifftshift_8_run`; same
/// arguments minus workspace / stream, same status-code family.
///
/// # Safety
/// NOTE(review): presumably host-side only — confirm. Until then, pass
/// pointers that satisfy the matching `_run` contract.
pub fn baracuda_kernels_ifftshift_8_can_implement(
batch: i64,
n: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
/// 16-byte-element inverse `fftshift`. Parameters and status codes are
/// those of `baracuda_kernels_ifftshift_4_run`.
///
/// # Safety
/// Same as `baracuda_kernels_ifftshift_4_run` with 16-byte cells.
pub fn baracuda_kernels_ifftshift_16_run(
batch: i64,
n: i32,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_ifftshift_16_run`;
/// same arguments minus workspace / stream, same status-code family.
///
/// # Safety
/// NOTE(review): presumably host-side only — confirm. Until then, pass
/// pointers that satisfy the matching `_run` contract.
pub fn baracuda_kernels_ifftshift_16_can_implement(
batch: i64,
n: i32,
x: *const c_void,
y: *const c_void,
) -> i32;
/// In-place scale of a `cufftComplex` buffer by a real scalar:
/// `y[i].x *= scale; y[i].y *= scale;`. Applied after `cufftExecC2C`
/// in the inverse direction to bake in the 1/N normalization
/// PyTorch expects.
///
/// - `numel`: number of complex cells in `y`.
/// - `scale`: real multiplier applied to both components.
/// - `y`: buffer updated in place.
/// - `workspace` / `workspace_bytes`: scratch buffer and its size in
///   bytes.
/// - `stream`: the `cudaStream_t`, passed as `*mut c_void`.
///
/// Returns the kernel-family status code (`0` success).
///
/// # Safety
/// `y` must point to `numel` `cufftComplex` cells; `stream` live.
pub fn baracuda_kernels_scale_inplace_c32_run(
numel: i64,
scale: f32,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_scale_inplace_c32`. Host-side only.
/// Takes the `_run` arguments minus workspace / stream and returns the
/// same status-code family (`0` means the kernel accepts the problem).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_scale_inplace_c32_can_implement(
numel: i64,
scale: f32,
y: *const c_void,
) -> i32;
/// In-place scale of a `cufftDoubleComplex` buffer by a real
/// scalar. f64 analogue of `baracuda_kernels_scale_inplace_c32_run`:
/// same argument order, with an `f64` `scale`.
///
/// Returns the kernel-family status code (`0` success).
///
/// # Safety
/// `y` must point to `numel` `cufftDoubleComplex` cells.
pub fn baracuda_kernels_scale_inplace_c64_run(
numel: i64,
scale: f64,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_scale_inplace_c64`. Host-side only.
/// Takes the `_run` arguments minus workspace / stream and returns the
/// same status-code family (`0` means the kernel accepts the problem).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_scale_inplace_c64_can_implement(
numel: i64,
scale: f64,
y: *const c_void,
) -> i32;
/// In-place scale of a real `f32` buffer. Used to bake the `1/N`
/// normalization into the output of `cufftExecC2R` (IRFFT).
///
/// - `numel`: number of `f32` cells in `y`.
/// - `scale`: multiplier.
/// - `y`: buffer updated in place.
/// - `workspace` / `workspace_bytes`: scratch buffer and its size in
///   bytes.
/// - `stream`: the `cudaStream_t`, passed as `*mut c_void`.
///
/// Returns the kernel-family status code (`0` success).
///
/// # Safety
/// `y` must point to `numel` `f32` cells.
pub fn baracuda_kernels_scale_inplace_real_f32_run(
numel: i64,
scale: f32,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_scale_inplace_real_f32`. Host-side only.
/// Takes the `_run` arguments minus workspace / stream and returns the
/// same status-code family (`0` means the kernel accepts the problem).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_scale_inplace_real_f32_can_implement(
numel: i64,
scale: f32,
y: *const c_void,
) -> i32;
/// In-place scale of a real `f64` buffer. f64 analogue of
/// `baracuda_kernels_scale_inplace_real_f32_run`: same argument order,
/// with an `f64` `scale`.
///
/// Returns the kernel-family status code (`0` success).
///
/// # Safety
/// `y` must point to `numel` `f64` cells.
pub fn baracuda_kernels_scale_inplace_real_f64_run(
numel: i64,
scale: f64,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_scale_inplace_real_f64`. Host-side only.
/// Takes the `_run` arguments minus workspace / stream and returns the
/// same status-code family (`0` means the kernel accepts the problem).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_scale_inplace_real_f64_can_implement(
numel: i64,
scale: f64,
y: *const c_void,
) -> i32;
/// N-D `fftshift` / `ifftshift` — single-pass general-permutation
/// kernel covering up to rank-8 tensors. The caller passes a per-
/// axis `shape`, per-axis `shift_amt` (0 for pass-through axes;
/// `n/2` for fftshift / `n - n/2` for ifftshift on shifted axes),
/// and per-axis contiguous `stride` (in elements). The same kernel
/// covers both directions — the direction lives entirely in the
/// `shift_amt` array.
///
/// 4-byte cell width (covers `f32`).
///
/// ABI: `(total, rank, shape, shift_amt, stride, x, y, ws,
/// ws_bytes, stream) -> i32`.
///
/// - `total`: number of cells in `x` / `y` (presumably the product of
///   `shape` — confirm against the kernel source).
/// - `rank`: number of axes, at most 8.
/// - `shape` / `shift_amt` / `stride`: host arrays of `rank` entries.
/// - `x` / `y`: input and output buffers.
/// - `workspace` / `workspace_bytes`: scratch buffer and its size in
///   bytes.
/// - `stream`: the `cudaStream_t`, passed as `*mut c_void`.
///
/// Returns the kernel-family status code (`0` success).
///
/// # Safety
/// `x` / `y` must each point to at least `total` cells of the
/// kernel's element width. `shape` / `shift_amt` / `stride` must
/// each point to at least `rank` valid entries (host memory).
/// `rank <= 8`. `stream` must be a live CUDA stream in the current
/// context.
pub fn baracuda_kernels_fftshift_nd_4_run(
total: i64,
rank: i32,
shape: *const i32,
shift_amt: *const i32,
stride: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_fftshift_nd_4_run`:
/// same problem description without the workspace and stream
/// arguments; returns the same status-code family (`0` means the
/// kernel accepts the problem).
///
/// # Safety
/// `shape` / `shift_amt` / `stride` should point to at least `rank`
/// valid host entries, as for `_run`. NOTE(review): presumably `x` /
/// `y` are not dereferenced — confirm against the kernel source.
pub fn baracuda_kernels_fftshift_nd_4_can_implement(
total: i64,
rank: i32,
shape: *const i32,
shift_amt: *const i32,
stride: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// 8-byte-cell N-D fftshift (covers `f64` and `Complex32`).
/// Parameters and status codes are those of
/// `baracuda_kernels_fftshift_nd_4_run`.
///
/// # Safety
/// Same as `baracuda_kernels_fftshift_nd_4_run` with 8-byte cells.
pub fn baracuda_kernels_fftshift_nd_8_run(
total: i64,
rank: i32,
shape: *const i32,
shift_amt: *const i32,
stride: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_fftshift_nd_8_run`;
/// same arguments minus workspace / stream, same status-code family.
///
/// # Safety
/// `shape` / `shift_amt` / `stride` should point to at least `rank`
/// valid host entries, as for `_run`. NOTE(review): presumably `x` /
/// `y` are not dereferenced — confirm.
pub fn baracuda_kernels_fftshift_nd_8_can_implement(
total: i64,
rank: i32,
shape: *const i32,
shift_amt: *const i32,
stride: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// 16-byte-cell N-D fftshift (covers `Complex64`). Parameters and
/// status codes are those of `baracuda_kernels_fftshift_nd_4_run`.
///
/// # Safety
/// Same as `baracuda_kernels_fftshift_nd_4_run` with 16-byte cells.
pub fn baracuda_kernels_fftshift_nd_16_run(
total: i64,
rank: i32,
shape: *const i32,
shift_amt: *const i32,
stride: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_fftshift_nd_16_run`;
/// same arguments minus workspace / stream, same status-code family.
///
/// # Safety
/// `shape` / `shift_amt` / `stride` should point to at least `rank`
/// valid host entries, as for `_run`. NOTE(review): presumably `x` /
/// `y` are not dereferenced — confirm.
pub fn baracuda_kernels_fftshift_nd_16_can_implement(
total: i64,
rank: i32,
shape: *const i32,
shift_amt: *const i32,
stride: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
// =============================================================================
// Milestone 6.14 — bespoke batched-`ormqr` + batched-QR dense Q/R
// materialization helpers (linalg family). cuSOLVER's `ormqr` is
// non-batched; this kernel fuses all batch slots into one launch so the
// small-matrix regime (where batched-QR is most useful) is not
// latency-bound. Scope: Side = Left; op ∈ {N, T} for f32 / f64, plus
// Complex32 / Complex64 `unmqr` runners (op = C supported; plain T is
// rejected by the Rust safe layer). Right-side variants deferred.
// =============================================================================
unsafe extern "C" {
/// Batched-`ormqr`, `f32`. Applies the implicit `Q` (or `Q^T`) from a
/// `BatchedQrPlan` packed output (`A_packed [B, M, K]` column-major
/// + `tau [B, K]`) to a stack of right-hand-side matrices
/// `C [B, M, N]` in place. One CUDA block per batch slot.
///
/// - `batch`, `m`, `n`, `k`: the `B`, `M`, `N`, `K` extents above.
/// - `side`: fixed to `0` (Left) in the trailblazer.
/// - `op`: `0` (N — apply Q) or `1` (T — apply Q^T).
/// - `a_packed` / `tau`: packed reflectors and their scalar factors.
/// - `c`: right-hand-side stack, overwritten in place.
/// - `workspace` / `workspace_bytes`: scratch buffer and its size in
///   bytes.
/// - `stream`: the `cudaStream_t`, passed as `*mut c_void`.
///
/// Status: 0 success, 2 invalid problem, 3 unsupported (e.g.
/// side = Right), 5 internal launch failure.
///
/// # Safety
/// `a_packed` must point to at least `batch * M * K` `f32` cells
/// (column-major); `tau` to at least `batch * K`; `c` to at least
/// `batch * M * N`. `stream` must be live.
pub fn baracuda_kernels_batched_ormqr_f32_run(
batch: i32,
m: i32,
n: i32,
k: i32,
side: i32,
op: i32,
a_packed: *const c_void,
tau: *const c_void,
c: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_batched_ormqr_f32_run`:
/// takes only the scalar problem description
/// `(batch, m, n, k, side, op)` and returns the same status-code
/// family (`0` means the kernel accepts the problem).
///
/// # Safety
/// Takes no pointer arguments; it is `unsafe` only because it is a
/// foreign function.
pub fn baracuda_kernels_batched_ormqr_f32_can_implement(
batch: i32, m: i32, n: i32, k: i32, side: i32, op: i32,
) -> i32;
/// Batched-`ormqr`, `f64`. Same contract as the `f32` variant
/// (`baracuda_kernels_batched_ormqr_f32_run`).
///
/// # Safety
/// Same as the `f32` variant with `f64` storage.
pub fn baracuda_kernels_batched_ormqr_f64_run(
batch: i32,
m: i32,
n: i32,
k: i32,
side: i32,
op: i32,
a_packed: *const c_void,
tau: *const c_void,
c: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_batched_ormqr_f64_run`;
/// same scalar arguments and status-code family as the `f32` check.
///
/// # Safety
/// Takes no pointer arguments; it is `unsafe` only because it is a
/// foreign function.
pub fn baracuda_kernels_batched_ormqr_f64_can_implement(
batch: i32, m: i32, n: i32, k: i32, side: i32, op: i32,
) -> i32;
/// Batched-`unmqr`, `Complex32`. Same shape/contract as the `f32`
/// variant (`baracuda_kernels_batched_ormqr_f32_run`) but with
/// `cuFloatComplex` storage. `op = 2` (C — conjugate transpose) is
/// supported; `op = 1` (T — plain transpose) is rejected by the Rust
/// safe layer for complex (mathematically unusual for Householder).
///
/// Returns the same status-code family as the `f32` runner
/// (`0` success).
///
/// # Safety
/// Pointer sizes counted in `cuFloatComplex` cells; layout otherwise
/// identical to the `f32` runner.
pub fn baracuda_kernels_batched_ormqr_complex32_run(
batch: i32,
m: i32,
n: i32,
k: i32,
side: i32,
op: i32,
a_packed: *const c_void,
tau: *const c_void,
c: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_batched_ormqr_complex32_run`: takes only the
/// scalar problem description `(batch, m, n, k, side, op)` and returns
/// the same status-code family (`0` means the kernel accepts the
/// problem).
///
/// # Safety
/// Takes no pointer arguments; it is `unsafe` only because it is a
/// foreign function.
pub fn baracuda_kernels_batched_ormqr_complex32_can_implement(
batch: i32, m: i32, n: i32, k: i32, side: i32, op: i32,
) -> i32;
/// Batched-`unmqr`, `Complex64`. Same as the `complex32` variant
/// with `cuDoubleComplex` storage.
///
/// # Safety
/// Same as the `complex32` variant with `cuDoubleComplex` storage.
pub fn baracuda_kernels_batched_ormqr_complex64_run(
batch: i32,
m: i32,
n: i32,
k: i32,
side: i32,
op: i32,
a_packed: *const c_void,
tau: *const c_void,
c: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_batched_ormqr_complex64_run`; same scalar
/// arguments and status-code family as the `complex32` check.
///
/// # Safety
/// Takes no pointer arguments; it is `unsafe` only because it is a
/// foreign function.
pub fn baracuda_kernels_batched_ormqr_complex64_can_implement(
batch: i32, m: i32, n: i32, k: i32, side: i32, op: i32,
) -> i32;
/// Materialize dense `R [B, K, N]` from a `geqrf`-packed
/// `A [B, M, N]` (column-major). `K = min(M, N)`. Cell `R[b, i, j]`
/// = `A[b, i, j]` if `i ≤ j`, else `0`. One CUDA block per
/// `(batch_slot, column)`. `f32`.
///
/// - `batch`, `m`, `n`, `k`: the `B`, `M`, `N`, `K` extents above.
/// - `a_packed`: `geqrf`-packed input.
/// - `r`: output buffer for the dense `R`.
/// - `workspace` / `workspace_bytes`: scratch buffer and its size in
///   bytes.
/// - `stream`: the `cudaStream_t`, passed as `*mut c_void`.
///
/// Returns the crate-wide status code (`0` success).
///
/// # Safety
/// `a_packed` ≥ `batch * M * N` `f32` cells; `r` ≥ `batch * K * N`.
pub fn baracuda_kernels_batched_qr_materialize_r_f32_run(
batch: i32,
m: i32,
n: i32,
k: i32,
a_packed: *const c_void,
r: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_batched_qr_materialize_r_f32_run`: takes only the
/// scalar extents `(batch, m, n, k)` and returns the crate-wide status
/// code (`0` means the kernel accepts the problem).
///
/// # Safety
/// Takes no pointer arguments; it is `unsafe` only because it is a
/// foreign function.
pub fn baracuda_kernels_batched_qr_materialize_r_f32_can_implement(
batch: i32, m: i32, n: i32, k: i32,
) -> i32;
/// Materialize dense `R`, `f64` analogue of
/// `baracuda_kernels_batched_qr_materialize_r_f32_run`.
///
/// # Safety
/// Same as the `f32` variant with `f64` storage.
pub fn baracuda_kernels_batched_qr_materialize_r_f64_run(
batch: i32,
m: i32,
n: i32,
k: i32,
a_packed: *const c_void,
r: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_batched_qr_materialize_r_f64_run`; same scalar
/// arguments and status code as the `f32` check.
///
/// # Safety
/// Takes no pointer arguments; it is `unsafe` only because it is a
/// foreign function.
pub fn baracuda_kernels_batched_qr_materialize_r_f64_can_implement(
batch: i32, m: i32, n: i32, k: i32,
) -> i32;
/// Stage a column-major identity `Q [B, M, M]` (one identity per
/// batch slot) into a freshly allocated buffer. Caller then chains
/// `baracuda_kernels_batched_ormqr_*_run` with `op = 0` (N) to
/// overwrite `Q` in place with the dense Q matrix from the
/// `geqrf`-packed input. `f32`.
///
/// - `batch`, `m`: the `B` and `M` extents above.
/// - `q`: output buffer receiving the identities.
/// - `workspace` / `workspace_bytes`: scratch buffer and its size in
///   bytes.
/// - `stream`: the `cudaStream_t`, passed as `*mut c_void`.
///
/// Returns the crate-wide status code (`0` success).
///
/// # Safety
/// `q` must point to at least `batch * M * M` `f32` cells.
pub fn baracuda_kernels_batched_qr_materialize_identity_f32_run(
batch: i32,
m: i32,
q: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_batched_qr_materialize_identity_f32_run`: takes
/// only the scalar extents `(batch, m)` and returns the crate-wide
/// status code (`0` means the kernel accepts the problem).
///
/// # Safety
/// Takes no pointer arguments; it is `unsafe` only because it is a
/// foreign function.
pub fn baracuda_kernels_batched_qr_materialize_identity_f32_can_implement(
batch: i32, m: i32,
) -> i32;
/// Stage identity, `f64` analogue of
/// `baracuda_kernels_batched_qr_materialize_identity_f32_run`.
///
/// # Safety
/// Same as the `f32` variant with `f64` storage.
pub fn baracuda_kernels_batched_qr_materialize_identity_f64_run(
batch: i32,
m: i32,
q: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_batched_qr_materialize_identity_f64_run`; same
/// scalar arguments and status code as the `f32` check.
///
/// # Safety
/// Takes no pointer arguments; it is `unsafe` only because it is a
/// foreign function.
pub fn baracuda_kernels_batched_qr_materialize_identity_f64_can_implement(
batch: i32, m: i32,
) -> i32;
}
// =============================================================================
// Milestone 6.17 — WY-blocked batched-`ormqr`. Companion to the GEMV-rates
// kernel above (`baracuda_kernels_batched_ormqr_*_run`). Two bespoke
// kernels (T-build + V-extract) pair with cuBLAS strided-batched GEMM at
// the safe-plan layer to lift the apply step from GEMV-rates to GEMM-
// rates. Scope: Side = Left, op ∈ {N, T}, dtype ∈ {f32, f64}; Phase 26
// adds Complex32 / Complex64 variants of both helper kernels.
// =============================================================================
unsafe extern "C" {
/// WY block T-build, `f32`. For each `(batch_slot, block_index)`,
/// builds the `[nb, nb]` upper-triangular block-reflector matrix `T`
/// such that `H_0 · ... · H_{nb-1} = I - V·T·V^T`. One CUDA block
/// per `(batch, num_blocks)` cell.
///
/// - `batch`, `m`, `k`: the `B`, `M`, `K` extents of the packed input.
/// - `nb`: block size (reflectors per block).
/// - `num_blocks`: number of reflector blocks.
/// - `a_packed` / `tau`: packed reflectors and their scalar factors.
/// - `t_out`: output buffer for the stacked `T` matrices.
/// - `workspace` / `workspace_bytes`: scratch buffer and its size in
///   bytes.
/// - `stream`: the `cudaStream_t`, passed as `*mut c_void`.
///
/// Status codes: 0 success, 2 invalid problem, 5 launch failure.
///
/// # Safety
/// `a_packed` ≥ `batch * M * K` `f32` cells (column-major); `tau` ≥
/// `batch * K`; `t_out` ≥ `batch * num_blocks * nb * nb`.
/// `num_blocks` must equal `(K + nb - 1) / nb`.
pub fn baracuda_kernels_batched_ormqr_wy_build_t_f32_run(
batch: i32,
m: i32,
k: i32,
nb: i32,
num_blocks: i32,
a_packed: *const c_void,
tau: *const c_void,
t_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_batched_ormqr_wy_build_t_f32_run`: takes only the
/// scalar problem description `(batch, m, k, nb, num_blocks)` and
/// returns the same status-code family (`0` means the kernel accepts
/// the problem).
///
/// # Safety
/// Takes no pointer arguments; it is `unsafe` only because it is a
/// foreign function.
pub fn baracuda_kernels_batched_ormqr_wy_build_t_f32_can_implement(
batch: i32, m: i32, k: i32, nb: i32, num_blocks: i32,
) -> i32;
/// WY block T-build, `f64` analogue of
/// `baracuda_kernels_batched_ormqr_wy_build_t_f32_run`.
///
/// # Safety
/// Same as the `f32` variant with `f64` storage.
pub fn baracuda_kernels_batched_ormqr_wy_build_t_f64_run(
batch: i32,
m: i32,
k: i32,
nb: i32,
num_blocks: i32,
a_packed: *const c_void,
tau: *const c_void,
t_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_batched_ormqr_wy_build_t_f64_run`; same scalar
/// arguments and status-code family as the `f32` check.
///
/// # Safety
/// Takes no pointer arguments; it is `unsafe` only because it is a
/// foreign function.
pub fn baracuda_kernels_batched_ormqr_wy_build_t_f64_can_implement(
batch: i32, m: i32, k: i32, nb: i32, num_blocks: i32,
) -> i32;
/// WY V-extraction, `f32`. Materializes the dense `V [B, M, nb]`
/// panel for one block of reflectors (block_start = `block_start`,
/// `block_k = min(nb, K - block_start)`) into a contiguous workspace
/// buffer. Sets the implicit-1 at each reflector's diagonal, copies
/// the packed-A strict lower below, zeros above the diagonal, and
/// zeros entire columns past `block_k` (handles the partial-last-
/// block case).
///
/// - `batch`, `m`, `k`: the `B`, `M`, `K` extents of the packed input.
/// - `nb`: block size (column count of the output panel).
/// - `block_start` / `block_k`: first reflector of the block and the
///   number of live reflectors in it.
/// - `a_packed`: packed reflectors.
/// - `v_out`: output buffer for the dense `V` panel.
/// - `workspace` / `workspace_bytes`: scratch buffer and its size in
///   bytes.
/// - `stream`: the `cudaStream_t`, passed as `*mut c_void`.
///
/// Returns the crate-wide status code (`0` success).
///
/// # Safety
/// `a_packed` ≥ `batch * M * K` `f32` cells; `v_out` ≥ `batch * M * nb`.
pub fn baracuda_kernels_batched_ormqr_wy_extract_v_f32_run(
batch: i32,
m: i32,
k: i32,
nb: i32,
block_start: i32,
block_k: i32,
a_packed: *const c_void,
v_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_batched_ormqr_wy_extract_v_f32_run`: takes only
/// the scalar problem description
/// `(batch, m, k, nb, block_start, block_k)` and returns the crate-wide
/// status code (`0` means the kernel accepts the problem).
///
/// # Safety
/// Takes no pointer arguments; it is `unsafe` only because it is a
/// foreign function.
pub fn baracuda_kernels_batched_ormqr_wy_extract_v_f32_can_implement(
batch: i32, m: i32, k: i32, nb: i32, block_start: i32, block_k: i32,
) -> i32;
/// WY V-extraction, `f64` analogue of
/// `baracuda_kernels_batched_ormqr_wy_extract_v_f32_run`.
///
/// # Safety
/// Same as the `f32` variant with `f64` storage.
pub fn baracuda_kernels_batched_ormqr_wy_extract_v_f64_run(
batch: i32,
m: i32,
k: i32,
nb: i32,
block_start: i32,
block_k: i32,
a_packed: *const c_void,
v_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_batched_ormqr_wy_extract_v_f64_run`; same scalar
/// arguments and status code as the `f32` check.
///
/// # Safety
/// Takes no pointer arguments; it is `unsafe` only because it is a
/// foreign function.
pub fn baracuda_kernels_batched_ormqr_wy_extract_v_f64_can_implement(
batch: i32, m: i32, k: i32, nb: i32, block_start: i32, block_k: i32,
) -> i32;
// ----- Phase 26: complex (Complex32 / Complex64) WY-block helpers ------
//
// The Householder reflectors emitted by `cublas{C,Z}geqrfBatched` are
// unitary not orthogonal, so the WY block reflector becomes
// `H_0·...·H_{nb-1} = I - V · T · V^H` (note `V^H` — the inner dot-
// products in the T-build kernel carry a `conj()` on the left factor;
// the V-extract kernel is pure copy and needs no math change).
//
// The kernel template is shared with f32 / f64 via the `mul_T` /
// `conj_T` element helpers in `baracuda_batched_ormqr.cuh`; the only
// dtype-specific surface is the FFI launcher symbol.
/// WY block T-build, `Complex32`. f32-complex analogue of the
/// `f32` variant (`baracuda_kernels_batched_ormqr_wy_build_t_f32_run`):
/// same argument order and buffer extents. Storage is `cuFloatComplex`
/// (== `Complex32`, ABI-compatible).
///
/// Returns the same status-code family as the `f32` variant
/// (`0` success).
///
/// # Safety
/// Same as the `f32` variant with `Complex32` (interleaved
/// `(f32, f32)`) storage.
pub fn baracuda_kernels_batched_ormqr_wy_build_t_complex32_run(
batch: i32,
m: i32,
k: i32,
nb: i32,
num_blocks: i32,
a_packed: *const c_void,
tau: *const c_void,
t_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_batched_ormqr_wy_build_t_complex32_run`: takes
/// only the scalar problem description
/// `(batch, m, k, nb, num_blocks)` and returns the same status-code
/// family (`0` means the kernel accepts the problem).
///
/// # Safety
/// Takes no pointer arguments; it is `unsafe` only because it is a
/// foreign function.
pub fn baracuda_kernels_batched_ormqr_wy_build_t_complex32_can_implement(
batch: i32, m: i32, k: i32, nb: i32, num_blocks: i32,
) -> i32;
/// WY block T-build, `Complex64`. f64-complex analogue of
/// `baracuda_kernels_batched_ormqr_wy_build_t_complex32_run`.
///
/// # Safety
/// Same as the `f32` variant with `Complex64` storage.
pub fn baracuda_kernels_batched_ormqr_wy_build_t_complex64_run(
batch: i32,
m: i32,
k: i32,
nb: i32,
num_blocks: i32,
a_packed: *const c_void,
tau: *const c_void,
t_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_batched_ormqr_wy_build_t_complex64_run`; same
/// scalar arguments and status-code family as the `complex32` check.
///
/// # Safety
/// Takes no pointer arguments; it is `unsafe` only because it is a
/// foreign function.
pub fn baracuda_kernels_batched_ormqr_wy_build_t_complex64_can_implement(
batch: i32, m: i32, k: i32, nb: i32, num_blocks: i32,
) -> i32;
/// WY V-extraction, `Complex32`. f32-complex analogue of
/// `baracuda_kernels_batched_ormqr_wy_extract_v_f32_run`, with the
/// same argument order. Pure copy kernel — sets the implicit-1 (as
/// `(1, 0)`), zeroes above the diagonal (as `(0, 0)`), copies the
/// strict lower below.
///
/// Returns the crate-wide status code (`0` success).
///
/// # Safety
/// Same as the `f32` variant with `Complex32` storage.
pub fn baracuda_kernels_batched_ormqr_wy_extract_v_complex32_run(
batch: i32,
m: i32,
k: i32,
nb: i32,
block_start: i32,
block_k: i32,
a_packed: *const c_void,
v_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_batched_ormqr_wy_extract_v_complex32_run`: takes
/// only the scalar problem description
/// `(batch, m, k, nb, block_start, block_k)` and returns the crate-wide
/// status code (`0` means the kernel accepts the problem).
///
/// # Safety
/// Takes no pointer arguments; it is `unsafe` only because it is a
/// foreign function.
pub fn baracuda_kernels_batched_ormqr_wy_extract_v_complex32_can_implement(
batch: i32, m: i32, k: i32, nb: i32, block_start: i32, block_k: i32,
) -> i32;
/// WY V-extraction, `Complex64`. f64-complex analogue of
/// `baracuda_kernels_batched_ormqr_wy_extract_v_complex32_run`.
///
/// # Safety
/// Same as the `f32` variant with `Complex64` storage.
pub fn baracuda_kernels_batched_ormqr_wy_extract_v_complex64_run(
batch: i32,
m: i32,
k: i32,
nb: i32,
block_start: i32,
block_k: i32,
a_packed: *const c_void,
v_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_batched_ormqr_wy_extract_v_complex64_run`; same
/// scalar arguments and status code as the `complex32` check.
///
/// # Safety
/// Takes no pointer arguments; it is `unsafe` only because it is a
/// foreign function.
pub fn baracuda_kernels_batched_ormqr_wy_extract_v_complex64_can_implement(
batch: i32, m: i32, k: i32, nb: i32, block_start: i32, block_k: i32,
) -> i32;
}
// =============================================================================
// cuDNN — convolution / pooling / CTC family infrastructure
// =============================================================================
//
// Gated behind the `cudnn` cargo feature. cuDNN is a separate NVIDIA
// download not bundled with the stock CUDA toolkit; the entire block
// below (types + constants + extern "C" fns) is invisible when the
// feature is off, so the rest of `baracuda-kernels-sys` builds cleanly
// on machines without cuDNN. Re-exported at the crate root via
// `pub use cudnn_ffi::*;` at the end of the module so callers don't
// see the wrapper.
//
// Phase 7 wraps cuDNN's legacy "v6" descriptor-based API (handle +
// tensor / filter / op descriptors → exec). The graph API (introduced
// in cuDNN 8 and reworked further in cuDNN 9) is a separate axis we
// defer for now — the legacy API is stable across cuDNN 8 / 9 and
// covers everything Milestones 7.1 (Conv2d) / 7.2 (Pooling) / 7.3
// (CTC) need.
//
// Linkage: `cargo:rustc-link-lib=dylib=cudnn` (added in build.rs). On
// Linux resolves to `libcudnn.so`; on Windows to `cudnn64_*.dll` from
// `CUDA_PATH\bin` (the cuDNN installer drops the DLLs alongside the
// CUDA toolkit by default).
//
// Conv-specific descriptor types + the convolution exec / workspace
// FFI live in this block (Milestone 7.1, this session). Pooling and
// CTC agents extend with their own descriptors only — the handle,
// tensor / filter descriptor types, data-type / format / NaN-prop
// constants, and create / destroy helpers below are shared.
#[cfg(feature = "cudnn")]
mod cudnn_ffi {
use super::*;
/// Opaque cuDNN handle. Stateful object — the plan layer creates one
/// lazily on first `run` and reuses it across launches. Like cuBLAS /
/// cuSOLVER, the handle is **not** thread-safe; the plan owns it via a
/// `Cell<>`. Note that `Cell` on its own only makes the owner `!Sync`;
/// the `!Send` half comes from the handle being a raw pointer
/// (`*mut c_void`), which is neither `Send` nor `Sync`.
// Name mirrors the C header, hence the lint exemption.
#[allow(non_camel_case_types)]
pub type cudnnHandle_t = *mut c_void;
/// Opaque cuDNN tensor descriptor. Carries shape + layout + element
/// type for an n-D tensor operand. Reused across launches by the plan
/// layer (set once on first `run`, mutated only if the descriptor
/// shape changes).
// Name mirrors the C header, hence the lint exemption.
#[allow(non_camel_case_types)]
pub type cudnnTensorDescriptor_t = *mut c_void;
/// Opaque cuDNN filter descriptor. Same representation as
/// [`cudnnTensorDescriptor_t`] (an opaque pointer) but carries an
/// `[output_channels, input_channels, ...]` filter-bank tensor
/// (semantics differ from a plain n-D tensor — cuDNN uses this to
/// express convolution weights, transposed-convolution weights, etc.).
// Name mirrors the C header, hence the lint exemption.
#[allow(non_camel_case_types)]
pub type cudnnFilterDescriptor_t = *mut c_void;
/// Opaque cuDNN convolution descriptor. Carries the
/// (pad, stride, dilation, mode, accumulator-dtype) tuple. Reused
/// across launches by the conv plan.
// Name mirrors the C header, hence the lint exemption.
#[allow(non_camel_case_types)]
pub type cudnnConvolutionDescriptor_t = *mut c_void;
/// Rust spelling of the C header's `cudnnDataType_t` enum: the
/// per-operand element-type tag handed to
/// `cudnnSetTensor4dDescriptor`, `cudnnSetFilter4dDescriptor` and
/// `cudnnSetConvolution2dDescriptor`.
// Name mirrors the C header, hence the lint exemption.
#[allow(non_camel_case_types)]
pub type cudnnDataType_t = i32;
/// `f32` elements (`CUDNN_DATA_FLOAT`).
pub const CUDNN_DATA_FLOAT: cudnnDataType_t = 0;
/// `f64` elements (`CUDNN_DATA_DOUBLE`).
pub const CUDNN_DATA_DOUBLE: cudnnDataType_t = 1;
/// IEEE 754 `f16` elements (`CUDNN_DATA_HALF`).
pub const CUDNN_DATA_HALF: cudnnDataType_t = 2;
/// Google `bf16` elements (`CUDNN_DATA_BFLOAT16`).
pub const CUDNN_DATA_BFLOAT16: cudnnDataType_t = 9;
/// Rust spelling of the C header's `cudnnTensorFormat_t` enum: picks
/// the storage order of a tensor — channel-first (NCHW, the PyTorch
/// default) or channel-last (NHWC, the TensorFlow default). The
/// trailblazer wires NCHW only; NHWC is a follow-up.
// Name mirrors the C header, hence the lint exemption.
#[allow(non_camel_case_types)]
pub type cudnnTensorFormat_t = i32;
/// Channel-first storage, PyTorch's default (`CUDNN_TENSOR_NCHW`).
pub const CUDNN_TENSOR_NCHW: cudnnTensorFormat_t = 0;
/// Channel-last storage, TensorFlow's default (`CUDNN_TENSOR_NHWC`).
pub const CUDNN_TENSOR_NHWC: cudnnTensorFormat_t = 1;
/// Rust spelling of the C header's `cudnnNanPropagation_t` enum:
/// decides whether pooling / activation kernels carry NaN values
/// through their reductions (relevant for max-pool and friends).
// Name mirrors the C header, hence the lint exemption.
#[allow(non_camel_case_types)]
pub type cudnnNanPropagation_t = i32;
/// NaN inputs are ignored — treated as `-inf` for max and `+inf` for
/// min (`CUDNN_NOT_PROPAGATE_NAN`). cuDNN default for the legacy
/// pooling API; matches PyTorch's pooling-on-NaN-input behavior.
///
/// NOTE(review): confirm the PyTorch-parity statement — recent PyTorch
/// max-pool kernels appear to propagate NaN.
pub const CUDNN_NOT_PROPAGATE_NAN: cudnnNanPropagation_t = 0;
/// NaN inputs flow through the reduction (`CUDNN_PROPAGATE_NAN`).
pub const CUDNN_PROPAGATE_NAN: cudnnNanPropagation_t = 1;
/// Rust spelling of the C header's `cudnnConvolutionMode_t` enum:
/// chooses between mathematical convolution (the kernel is flipped
/// before the multiply-accumulate) and cross-correlation (the kernel
/// is applied as-is). **PyTorch's `torch.nn.Conv2d` is actually
/// cross-correlation** despite its name, so pass
/// [`CUDNN_CROSS_CORRELATION`] for PyTorch parity.
// Name mirrors the C header, hence the lint exemption.
#[allow(non_camel_case_types)]
pub type cudnnConvolutionMode_t = i32;
/// True mathematical convolution with a flipped kernel
/// (`CUDNN_CONVOLUTION`). Rarely what callers want — PyTorch's
/// `Conv2d` is cross-correlation.
pub const CUDNN_CONVOLUTION: cudnnConvolutionMode_t = 0;
/// Kernel applied directly, i.e. PyTorch `Conv2d` semantics
/// (`CUDNN_CROSS_CORRELATION`). Use this unless there is a specific
/// reason to flip the kernel.
pub const CUDNN_CROSS_CORRELATION: cudnnConvolutionMode_t = 1;
/// Rust spelling of the C header's `cudnnConvolutionFwdAlgo_t` enum
/// (forward-convolution algorithm selector). The trailblazer pins
/// [`CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM`] (algo `0`) as the
/// universally supported baseline; heuristic search through
/// `cudnnGetConvolutionForwardAlgorithm_v7` is a follow-up.
// Name mirrors the C header, hence the lint exemption.
#[allow(non_camel_case_types)]
pub type cudnnConvolutionFwdAlgo_t = i32;
/// Universal-coverage baseline: works for any input / filter shape
/// that cuDNN itself accepts
/// (`CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM`).
pub const CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM: cudnnConvolutionFwdAlgo_t = 0;
/// Rust spelling of the C header's `cudnnConvolutionBwdDataAlgo_t`
/// enum (backward-data convolution algorithm selector). The
/// trailblazer pins [`CUDNN_CONVOLUTION_BWD_DATA_ALGO_1`] (algo `1`,
/// the `IMPLICIT_PRECOMP_GEMM`-style universal baseline for backward
/// data).
// Name mirrors the C header, hence the lint exemption.
#[allow(non_camel_case_types)]
pub type cudnnConvolutionBwdDataAlgo_t = i32;
/// Universal-coverage baseline for the data-gradient pass
/// (`CUDNN_CONVOLUTION_BWD_DATA_ALGO_1`).
pub const CUDNN_CONVOLUTION_BWD_DATA_ALGO_1: cudnnConvolutionBwdDataAlgo_t = 1;
/// cuDNN backward-filter-convolution algorithm tag —
/// `cudnnConvolutionBwdFilterAlgo_t` from the C header. Trailblazer
/// pins [`CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1`].
///
/// Consumed as the `algo` argument of
/// [`cudnnGetConvolutionBackwardFilterWorkspaceSize`] and
/// [`cudnnConvolutionBackwardFilter`].
#[allow(non_camel_case_types)]
pub type cudnnConvolutionBwdFilterAlgo_t = i32;
/// `CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1` — universal-coverage
/// baseline for the filter-gradient pass. A
/// [`cudnnConvolutionBwdFilterAlgo_t`] value.
pub const CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1: i32 = 1;
/// `CUDNN_STATUS_SUCCESS` — the only success code. Any non-zero
/// return from a cuDNN routine is mapped to a negative status at the
/// safe-plan layer for distinct error reporting (mirrors the
/// `CUSOLVER_STATUS_SUCCESS` convention).
///
/// Compare the `i32` returned by the `cudnn*` declarations in this
/// module against this constant rather than a bare `0`.
pub const CUDNN_STATUS_SUCCESS: i32 = 0;
/// Opaque cuDNN CTC-loss descriptor. Carries the `(comp_type,
/// norm_mode, grad_mode)` tuple configured via
/// [`cudnnSetCTCLossDescriptorEx`]. Owned by the CTC plan, reused
/// across launches.
///
/// Allocate with [`cudnnCreateCTCLossDescriptor`] and release with
/// [`cudnnDestroyCTCLossDescriptor`].
#[allow(non_camel_case_types)]
pub type cudnnCTCLossDescriptor_t = *mut c_void;
/// cuDNN CTC-loss algorithm tag — `cudnnCTCLossAlgo_t` from the C
/// header. Selects between deterministic and non-deterministic
/// internal algorithms (cuDNN's non-deterministic variant uses
/// atomic-add reductions for higher throughput on large batches).
///
/// Consumed as the `algo` argument of
/// [`cudnnGetCTCLossWorkspaceSize`] and [`cudnnCTCLoss`].
#[allow(non_camel_case_types)]
pub type cudnnCTCLossAlgo_t = i32;
/// `CUDNN_CTC_LOSS_ALGO_DETERMINISTIC` — bit-stable across runs
/// on the same hardware. Trailblazer default for parity with the
/// bespoke CTC plan. A [`cudnnCTCLossAlgo_t`] value.
pub const CUDNN_CTC_LOSS_ALGO_DETERMINISTIC: i32 = 0;
/// `CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC` — faster on large
/// batches but introduces atomicAdd-induced non-determinism. A
/// [`cudnnCTCLossAlgo_t`] value.
pub const CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC: i32 = 1;
/// cuDNN loss-normalization-mode tag —
/// `cudnnLossNormalizationMode_t` from the C header. Selects how
/// cuDNN interprets the input probability tensor for losses that
/// can take either raw probabilities or pre-softmaxed log-probs.
///
/// Consumed as the `norm_mode` argument of
/// [`cudnnSetCTCLossDescriptorEx`].
#[allow(non_camel_case_types)]
pub type cudnnLossNormalizationMode_t = i32;
/// `CUDNN_LOSS_NORMALIZATION_NONE` — input is raw probabilities;
/// caller is responsible for the softmax (and the entries must
/// sum to 1 along the class axis).
pub const CUDNN_LOSS_NORMALIZATION_NONE: i32 = 0;
/// `CUDNN_LOSS_NORMALIZATION_SOFTMAX` — input is log-probs;
/// cuDNN applies log-softmax internally to recover the
/// normalization. Trailblazer default for parity with the
/// bespoke CTC plan's `log_probs` convention.
///
/// NOTE(review): cuDNN's reference describes this mode as taking
/// unnormalized activations and applying softmax internally;
/// log-probs are one valid such input — confirm the "log-softmax"
/// wording above against the cuDNN documentation.
pub const CUDNN_LOSS_NORMALIZATION_SOFTMAX: i32 = 1;
// ----- pooling (Milestone 7.2) -------------------------------------------
/// Opaque cuDNN pooling descriptor. Carries the
/// `(mode, nan_prop, window, pad, stride)` tuple. Reused across launches
/// by the pool plan.
///
/// Allocate with [`cudnnCreatePoolingDescriptor`], configure with
/// [`cudnnSetPooling2dDescriptor`] or [`cudnnSetPoolingNdDescriptor`],
/// and release with [`cudnnDestroyPoolingDescriptor`].
#[allow(non_camel_case_types)]
pub type cudnnPoolingDescriptor_t = *mut c_void;
/// cuDNN pooling-mode tag — `cudnnPoolingMode_t` from the C header.
/// Selects max / avg-include-padding / avg-exclude-padding /
/// max-deterministic. PyTorch's `nn.MaxPool2d` corresponds to
/// [`CUDNN_POOLING_MAX`]; `nn.AvgPool2d` with the PyTorch default
/// `count_include_pad=True` corresponds to
/// [`CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING`], while
/// `count_include_pad=False` corresponds to
/// [`CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING`].
#[allow(non_camel_case_types)]
pub type cudnnPoolingMode_t = i32;
/// `CUDNN_POOLING_MAX` — pick the maximum element in each window.
/// PyTorch / TensorFlow default for max-pool.
pub const CUDNN_POOLING_MAX: i32 = 0;
/// `CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING` — average over the
/// full `window_h * window_w` denominator (padded zeros included).
/// Matches PyTorch's `nn.AvgPool2d` default
/// (`count_include_pad=True`).
pub const CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING: i32 = 1;
/// `CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING` — average over the
/// number of *valid* (non-padded) elements per window. Matches
/// PyTorch's `nn.AvgPool2d(count_include_pad=False)` — note that this
/// is **not** the PyTorch default.
pub const CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING: i32 = 2;
/// `CUDNN_POOLING_MAX_DETERMINISTIC` — same semantics as
/// [`CUDNN_POOLING_MAX`] but uses cuDNN's deterministic max-pool path
/// for reproducible bit-identical output across launches (slightly
/// slower).
pub const CUDNN_POOLING_MAX_DETERMINISTIC: i32 = 3;
unsafe extern "C" {
// ----- handle lifecycle ----------------------------------------------
/// `cudnnCreate(handle)` — allocate a cuDNN handle and store it
/// through `handle`. Returns [`CUDNN_STATUS_SUCCESS`] (`0`) on
/// success.
///
/// Release the handle with [`cudnnDestroy`] once it is no longer
/// needed.
///
/// # Safety
/// `handle` must point to writable storage for one
/// [`cudnnHandle_t`].
pub fn cudnnCreate(handle: *mut cudnnHandle_t) -> i32;
/// `cudnnDestroy(handle)` — release a cuDNN handle. Returns
/// [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// # Safety
/// `handle` must be a valid handle returned by [`cudnnCreate`]
/// that has not been previously destroyed; passing the same handle
/// twice violates this contract.
pub fn cudnnDestroy(handle: cudnnHandle_t) -> i32;
/// `cudnnSetStream(handle, stream)` — bind subsequent cuDNN calls
/// to the given CUDA stream. Returns [`CUDNN_STATUS_SUCCESS`] (`0`)
/// on success.
///
/// The exec entry points declared in this module (e.g.
/// [`cudnnConvolutionForward`], [`cudnnPoolingForward`],
/// [`cudnnCTCLoss`]) take no stream argument; they rely on this
/// binding.
///
/// # Safety
/// `handle` must be a live cuDNN handle; `stream` must be a
/// valid CUDA stream in the current context (or null for the
/// default stream).
pub fn cudnnSetStream(handle: cudnnHandle_t, stream: *mut c_void) -> i32;
// ----- tensor descriptor lifecycle ----------------------------------
/// `cudnnCreateTensorDescriptor(desc)` — allocate a tensor
/// descriptor and store it through `desc`. Returns
/// [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// Configure the result with [`cudnnSetTensor4dDescriptor`] or
/// [`cudnnSetTensorNdDescriptor`]; release it with
/// [`cudnnDestroyTensorDescriptor`].
///
/// # Safety
/// `desc` must point to writable storage for one
/// [`cudnnTensorDescriptor_t`].
pub fn cudnnCreateTensorDescriptor(desc: *mut cudnnTensorDescriptor_t) -> i32;
/// `cudnnDestroyTensorDescriptor(desc)` — release a tensor
/// descriptor. Returns [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// # Safety
/// `desc` must be a valid descriptor returned by
/// [`cudnnCreateTensorDescriptor`] that has not been previously
/// destroyed.
pub fn cudnnDestroyTensorDescriptor(desc: cudnnTensorDescriptor_t) -> i32;
/// `cudnnSetTensor4dDescriptor(desc, format, dtype, n, c, h, w)`
/// — configure a rank-4 (`NCHW` or `NHWC`) tensor descriptor.
/// Returns [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// `format` is a [`cudnnTensorFormat_t`] constant
/// ([`CUDNN_TENSOR_NCHW`] / [`CUDNN_TENSOR_NHWC`]); `dtype` is a
/// [`cudnnDataType_t`] constant. The `(n, c, h, w)` axes always
/// follow the NCHW logical order regardless of `format` — the
/// `format` arg only changes how cuDNN computes the underlying
/// strides.
///
/// Note the argument order: `format` comes *before* `dtype` here,
/// the reverse of [`cudnnSetFilter4dDescriptor`].
///
/// # Safety
/// `desc` must be a live tensor descriptor.
pub fn cudnnSetTensor4dDescriptor(
desc: cudnnTensorDescriptor_t,
format: cudnnTensorFormat_t,
dtype: cudnnDataType_t,
n: i32,
c: i32,
h: i32,
w: i32,
) -> i32;
// ----- filter descriptor lifecycle ----------------------------------
/// `cudnnCreateFilterDescriptor(desc)` — allocate a filter
/// descriptor and store it through `desc`. Returns
/// [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// Configure the result with [`cudnnSetFilter4dDescriptor`] or
/// [`cudnnSetFilterNdDescriptor`]; release it with
/// [`cudnnDestroyFilterDescriptor`].
///
/// # Safety
/// `desc` must point to writable storage for one
/// [`cudnnFilterDescriptor_t`].
pub fn cudnnCreateFilterDescriptor(desc: *mut cudnnFilterDescriptor_t) -> i32;
/// `cudnnDestroyFilterDescriptor(desc)` — release a filter
/// descriptor. Returns [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// # Safety
/// `desc` must be a valid descriptor returned by
/// [`cudnnCreateFilterDescriptor`] that has not been previously
/// destroyed.
pub fn cudnnDestroyFilterDescriptor(desc: cudnnFilterDescriptor_t) -> i32;
/// `cudnnSetFilter4dDescriptor(desc, dtype, format, k, c, h, w)`
/// — configure a rank-4 filter bank descriptor. Returns
/// [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// `k` is `output_channels` (number of filters); `c` is
/// `input_channels` per filter; `(h, w)` are the per-filter
/// spatial extents. Note: the argument order in the cuDNN C ABI
/// puts `dtype` *before* `format` for filter descriptors
/// (opposite of `cudnnSetTensor4dDescriptor`); we mirror that
/// here. Both tags are plain `i32` aliases, so swapping them
/// compiles silently — double-check the order at call sites.
///
/// # Safety
/// `desc` must be a live filter descriptor.
pub fn cudnnSetFilter4dDescriptor(
desc: cudnnFilterDescriptor_t,
dtype: cudnnDataType_t,
format: cudnnTensorFormat_t,
k: i32,
c: i32,
h: i32,
w: i32,
) -> i32;
// ----- convolution descriptor lifecycle -----------------------------
/// `cudnnCreateConvolutionDescriptor(desc)` — allocate a
/// convolution descriptor and store it through `desc`. Returns
/// [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// Configure the result with [`cudnnSetConvolution2dDescriptor`] or
/// [`cudnnSetConvolutionNdDescriptor`]; release it with
/// [`cudnnDestroyConvolutionDescriptor`].
///
/// # Safety
/// `desc` must point to writable storage for one
/// [`cudnnConvolutionDescriptor_t`].
pub fn cudnnCreateConvolutionDescriptor(
desc: *mut cudnnConvolutionDescriptor_t,
) -> i32;
/// `cudnnDestroyConvolutionDescriptor(desc)` — release a
/// convolution descriptor. Returns [`CUDNN_STATUS_SUCCESS`] (`0`)
/// on success.
///
/// # Safety
/// `desc` must be a valid descriptor returned by
/// [`cudnnCreateConvolutionDescriptor`] that has not been
/// previously destroyed.
pub fn cudnnDestroyConvolutionDescriptor(
desc: cudnnConvolutionDescriptor_t,
) -> i32;
/// `cudnnSetConvolution2dDescriptor(desc, pad_h, pad_w, u, v,
/// dilation_h, dilation_w, mode, compute_type)` — configure a
/// 2-D convolution descriptor. Returns [`CUDNN_STATUS_SUCCESS`]
/// (`0`) on success.
///
/// `pad_h`/`pad_w` are the per-axis zero-padding amounts;
/// `u`/`v` are the vertical / horizontal **stride**;
/// `dilation_h`/`dilation_w` are the per-axis dilation factors;
/// `mode` is a [`cudnnConvolutionMode_t`] (pass
/// [`CUDNN_CROSS_CORRELATION`] for PyTorch parity);
/// `compute_type` is the **accumulator** dtype (a
/// [`cudnnDataType_t`] constant — typically `f32` even when the
/// operand dtype is `f16` / `bf16`).
///
/// For grouped convolution, follow this call with
/// [`cudnnSetConvolutionGroupCount`].
///
/// # Safety
/// `desc` must be a live convolution descriptor.
pub fn cudnnSetConvolution2dDescriptor(
desc: cudnnConvolutionDescriptor_t,
pad_h: i32,
pad_w: i32,
u: i32,
v: i32,
dilation_h: i32,
dilation_w: i32,
mode: cudnnConvolutionMode_t,
compute_type: cudnnDataType_t,
) -> i32;
// ----- workspace size queries ---------------------------------------
/// `cudnnGetConvolutionForwardWorkspaceSize(handle, x_desc,
/// w_desc, conv_desc, y_desc, algo, &size)` — query the
/// workspace size (in bytes) the FW conv kernel needs for the
/// given (x_desc, w_desc, conv_desc, y_desc, algo) combination.
/// Returns [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// The value written through `size_in_bytes` is what
/// [`cudnnConvolutionForward`] expects as `workspace_bytes` when
/// called with the same descriptors and `algo`.
///
/// # Safety
/// All descriptors must be live and consistent; `size_in_bytes`
/// must point to writable storage for one `usize` (cuDNN's
/// `size_t`).
pub fn cudnnGetConvolutionForwardWorkspaceSize(
handle: cudnnHandle_t,
x_desc: cudnnTensorDescriptor_t,
w_desc: cudnnFilterDescriptor_t,
conv_desc: cudnnConvolutionDescriptor_t,
y_desc: cudnnTensorDescriptor_t,
algo: cudnnConvolutionFwdAlgo_t,
size_in_bytes: *mut usize,
) -> i32;
/// `cudnnGetConvolutionBackwardDataWorkspaceSize(handle, w_desc,
/// dy_desc, conv_desc, dx_desc, algo, &size)` — query the
/// workspace size (in bytes) the BW-data conv kernel needs.
/// Returns [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// The value written through `size_in_bytes` is what
/// [`cudnnConvolutionBackwardData`] expects as `workspace_bytes`
/// when called with the same descriptors and `algo`.
///
/// # Safety
/// `handle` and all descriptors must be live and consistent;
/// `size_in_bytes` must point to writable storage for one `usize`
/// (cuDNN's `size_t`).
pub fn cudnnGetConvolutionBackwardDataWorkspaceSize(
handle: cudnnHandle_t,
w_desc: cudnnFilterDescriptor_t,
dy_desc: cudnnTensorDescriptor_t,
conv_desc: cudnnConvolutionDescriptor_t,
dx_desc: cudnnTensorDescriptor_t,
algo: cudnnConvolutionBwdDataAlgo_t,
size_in_bytes: *mut usize,
) -> i32;
/// `cudnnGetConvolutionBackwardFilterWorkspaceSize(handle,
/// x_desc, dy_desc, conv_desc, dw_desc, algo, &size)` — query
/// the workspace size (in bytes) the BW-filter conv kernel needs.
/// Returns [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// The value written through `size_in_bytes` is what
/// [`cudnnConvolutionBackwardFilter`] expects as `workspace_bytes`
/// when called with the same descriptors and `algo`.
///
/// # Safety
/// `handle` and all descriptors must be live and consistent;
/// `size_in_bytes` must point to writable storage for one `usize`
/// (cuDNN's `size_t`).
pub fn cudnnGetConvolutionBackwardFilterWorkspaceSize(
handle: cudnnHandle_t,
x_desc: cudnnTensorDescriptor_t,
dy_desc: cudnnTensorDescriptor_t,
conv_desc: cudnnConvolutionDescriptor_t,
dw_desc: cudnnFilterDescriptor_t,
algo: cudnnConvolutionBwdFilterAlgo_t,
size_in_bytes: *mut usize,
) -> i32;
// ----- convolution exec ---------------------------------------------
/// `cudnnConvolutionForward(handle, alpha, x_desc, x, w_desc, w,
/// conv_desc, algo, workspace, workspace_bytes, beta, y_desc, y)`
/// — run the FW conv. Computes
/// `y := alpha · conv(x, w) + beta · y`. Returns
/// [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// `alpha` and `beta` point to a single host-side scalar of the
/// **accumulator** dtype (typically `f32` — the same dtype as the
/// `compute_type` passed to `cudnnSetConvolution2dDescriptor`).
/// For pure-store semantics pass `*alpha = 1.0, *beta = 0.0`.
///
/// # Safety
/// `handle` and all descriptors must be live and mutually
/// consistent. `x`, `w`, and `y` must be device pointers valid for
/// the shapes their descriptors describe. `workspace` must point to
/// at least `workspace_bytes` of device memory (size it with
/// [`cudnnGetConvolutionForwardWorkspaceSize`]). `alpha` / `beta`
/// must point to readable **host** scalars: the cuDNN reference
/// documents these scaling factors as host-memory pointers, and
/// this module declares no cuBLAS-style pointer-mode switch. Stream
/// is implicit via [`cudnnSetStream`].
pub fn cudnnConvolutionForward(
handle: cudnnHandle_t,
alpha: *const c_void,
x_desc: cudnnTensorDescriptor_t,
x: *const c_void,
w_desc: cudnnFilterDescriptor_t,
w: *const c_void,
conv_desc: cudnnConvolutionDescriptor_t,
algo: cudnnConvolutionFwdAlgo_t,
workspace: *mut c_void,
workspace_bytes: usize,
beta: *const c_void,
y_desc: cudnnTensorDescriptor_t,
y: *mut c_void,
) -> i32;
/// `cudnnConvolutionBackwardData(handle, alpha, w_desc, w,
/// dy_desc, dy, conv_desc, algo, workspace, workspace_bytes,
/// beta, dx_desc, dx)` — data-gradient pass. Computes
/// `dx := alpha · conv^T(w, dy) + beta · dx`. Returns
/// [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// Size `workspace` with
/// [`cudnnGetConvolutionBackwardDataWorkspaceSize`] using the same
/// descriptors and `algo`.
///
/// # Safety
/// As for [`cudnnConvolutionForward`], with `w` / `dy` as the
/// read-only device operands and `dx` as the written device output.
pub fn cudnnConvolutionBackwardData(
handle: cudnnHandle_t,
alpha: *const c_void,
w_desc: cudnnFilterDescriptor_t,
w: *const c_void,
dy_desc: cudnnTensorDescriptor_t,
dy: *const c_void,
conv_desc: cudnnConvolutionDescriptor_t,
algo: cudnnConvolutionBwdDataAlgo_t,
workspace: *mut c_void,
workspace_bytes: usize,
beta: *const c_void,
dx_desc: cudnnTensorDescriptor_t,
dx: *mut c_void,
) -> i32;
/// `cudnnConvolutionBackwardFilter(handle, alpha, x_desc, x,
/// dy_desc, dy, conv_desc, algo, workspace, workspace_bytes,
/// beta, dw_desc, dw)` — filter-gradient pass. Computes
/// `dw := alpha · conv_grad(x, dy) + beta · dw`. Returns
/// [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// Size `workspace` with
/// [`cudnnGetConvolutionBackwardFilterWorkspaceSize`] using the same
/// descriptors and `algo`.
///
/// # Safety
/// As for [`cudnnConvolutionForward`], with `x` / `dy` as the
/// read-only device operands and `dw` as the written device output.
pub fn cudnnConvolutionBackwardFilter(
handle: cudnnHandle_t,
alpha: *const c_void,
x_desc: cudnnTensorDescriptor_t,
x: *const c_void,
dy_desc: cudnnTensorDescriptor_t,
dy: *const c_void,
conv_desc: cudnnConvolutionDescriptor_t,
algo: cudnnConvolutionBwdFilterAlgo_t,
workspace: *mut c_void,
workspace_bytes: usize,
beta: *const c_void,
dw_desc: cudnnFilterDescriptor_t,
dw: *mut c_void,
) -> i32;
// ----- n-D tensor descriptor (used by CTC for the [T, B, C] probs/grads
// tensor) -----------------------------------------------------------------
/// `cudnnSetTensorNdDescriptor(desc, dtype, nb_dims, dim_a, stride_a)`
/// — configure an n-D tensor descriptor with caller-supplied
/// dimensions and strides. Returns [`CUDNN_STATUS_SUCCESS`] (`0`)
/// on success.
///
/// Used by the CTC plan to describe the rank-3 `[T, B, C]`
/// log-probability / gradient tensors that don't fit the rank-4
/// `NCHW` / `NHWC` shape that [`cudnnSetTensor4dDescriptor`] expects.
/// Also used by Conv1d / Conv3d / ConvTranspose plans for rank-3
/// `[N, C, L]` and rank-5 `[N, C, D, H, W]` activation tensors.
///
/// Unlike the rank-4 setter there is no `format` argument: the
/// memory layout is expressed entirely through `stride_a`.
///
/// # Safety
/// `desc` must be a live tensor descriptor; `dim_a` and `stride_a`
/// must each point to at least `nb_dims` readable `i32` values.
pub fn cudnnSetTensorNdDescriptor(
desc: cudnnTensorDescriptor_t,
dtype: cudnnDataType_t,
nb_dims: i32,
dim_a: *const i32,
stride_a: *const i32,
) -> i32;
/// `cudnnSetFilterNdDescriptor(desc, dtype, format, nb_dims, dim_a)`
/// — configure an n-D filter descriptor. Returns
/// [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// Used by Conv1d / Conv3d / ConvTranspose plans to describe the
/// `[C_out, C_in_per_group, L_filt]` (rank-3) and `[C_out,
/// C_in_per_group, D_filt, H_filt, W_filt]` (rank-5) filter shapes.
/// The leading two axes are always `K = C_out` and `C =
/// C_in / groups`.
///
/// `nb_dims` here is the full filter rank (3 or 5 in the shapes
/// above), and — as with [`cudnnSetFilter4dDescriptor`] — `dtype`
/// precedes `format`.
///
/// # Safety
/// `desc` must be a live filter descriptor; `dim_a` must point to
/// at least `nb_dims` readable `i32` values.
pub fn cudnnSetFilterNdDescriptor(
desc: cudnnFilterDescriptor_t,
dtype: cudnnDataType_t,
format: cudnnTensorFormat_t,
nb_dims: i32,
dim_a: *const i32,
) -> i32;
/// `cudnnSetConvolutionNdDescriptor(desc, array_length, pad_a,
/// stride_a, dilation_a, mode, compute_type)` — configure an n-D
/// convolution descriptor. Returns [`CUDNN_STATUS_SUCCESS`] (`0`)
/// on success.
///
/// `array_length` is the spatial rank (1 for Conv1d, 2 for Conv2d,
/// 3 for Conv3d) — **not** the tensor rank. Activation and filter
/// tensors carry two extra axes (`[N, C, ...]` / `[K, C, ...]`)
/// not counted here. `pad_a` / `stride_a` / `dilation_a` must each
/// point to `array_length` `i32` values. `mode` and `compute_type`
/// have the same meaning as in [`cudnnSetConvolution2dDescriptor`].
///
/// # Safety
/// `desc` must be a live convolution descriptor; the three array
/// pointers must each point to at least `array_length` readable
/// `i32` values.
pub fn cudnnSetConvolutionNdDescriptor(
desc: cudnnConvolutionDescriptor_t,
array_length: i32,
pad_a: *const i32,
stride_a: *const i32,
dilation_a: *const i32,
mode: cudnnConvolutionMode_t,
compute_type: cudnnDataType_t,
) -> i32;
/// `cudnnSetConvolutionGroupCount(desc, group_count)` — set the
/// group count for grouped / depthwise convolution. Returns
/// [`CUDNN_STATUS_SUCCESS`] (`0`) on success. `group_count == 1`
/// (the cuDNN default) is plain dense convolution;
/// `group_count == c_in` is depthwise.
///
/// Must be called **after** `cudnnSetConvolution{2d,Nd}Descriptor`
/// — that call resets the group count to 1.
///
/// With groups enabled, the filter descriptor's channel axis is
/// `C_in / groups` (see [`cudnnSetFilterNdDescriptor`]).
///
/// # Safety
/// `desc` must be a live convolution descriptor.
pub fn cudnnSetConvolutionGroupCount(
desc: cudnnConvolutionDescriptor_t,
group_count: i32,
) -> i32;
// ----- CTC loss --------------------------------------------------------
/// `cudnnCreateCTCLossDescriptor(ctc_desc)` — allocate a CTC-loss
/// descriptor and store it through `ctc_desc`. Returns
/// [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// Configure the result with [`cudnnSetCTCLossDescriptorEx`];
/// release it with [`cudnnDestroyCTCLossDescriptor`].
///
/// # Safety
/// `ctc_desc` must point to writable storage for one
/// [`cudnnCTCLossDescriptor_t`].
pub fn cudnnCreateCTCLossDescriptor(ctc_desc: *mut cudnnCTCLossDescriptor_t) -> i32;
/// `cudnnDestroyCTCLossDescriptor(ctc_desc)` — release a CTC-loss
/// descriptor. Returns [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// # Safety
/// `ctc_desc` must be a valid descriptor returned by
/// [`cudnnCreateCTCLossDescriptor`] that has not been previously
/// destroyed.
pub fn cudnnDestroyCTCLossDescriptor(ctc_desc: cudnnCTCLossDescriptor_t) -> i32;
/// `cudnnSetCTCLossDescriptorEx(ctc_desc, comp_type, norm_mode,
/// grad_mode)` — configure a CTC-loss descriptor. Returns
/// [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// `comp_type` is the compute (accumulator) [`cudnnDataType_t`]
/// ([`CUDNN_DATA_FLOAT`] or [`CUDNN_DATA_DOUBLE`]). `norm_mode` is a
/// [`cudnnLossNormalizationMode_t`]: pass
/// [`CUDNN_LOSS_NORMALIZATION_SOFTMAX`] when the input is log-probs
/// (cuDNN softmaxes internally); pass
/// [`CUDNN_LOSS_NORMALIZATION_NONE`] for raw probabilities.
/// `grad_mode` is a [`cudnnNanPropagation_t`] controlling whether
/// NaN gradients propagate.
///
/// # Safety
/// `ctc_desc` must be a live CTC-loss descriptor.
pub fn cudnnSetCTCLossDescriptorEx(
ctc_desc: cudnnCTCLossDescriptor_t,
comp_type: cudnnDataType_t,
norm_mode: cudnnLossNormalizationMode_t,
grad_mode: cudnnNanPropagation_t,
) -> i32;
/// `cudnnGetCTCLossWorkspaceSize(handle, probs_desc, grads_desc,
/// labels, label_lengths, input_lengths, algo, ctc_desc,
/// size_in_bytes)` — compute the device workspace needed for a
/// CTC-loss forward+backward fused call. Returns
/// [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// `labels`, `label_lengths`, `input_lengths` are **host-side**
/// `i32` arrays. `labels` is the concatenation of per-batch
/// label sequences (length `Σ label_lengths`); `label_lengths`
/// and `input_lengths` are each length `B` (batch).
///
/// The value written through `size_in_bytes` is what
/// [`cudnnCTCLoss`] expects as `workspace_bytes`.
///
/// # Safety
/// All descriptors must be live; host arrays must be valid for
/// their stated lengths; `size_in_bytes` must point to writable
/// storage for one `usize`.
pub fn cudnnGetCTCLossWorkspaceSize(
handle: cudnnHandle_t,
probs_desc: cudnnTensorDescriptor_t,
grads_desc: cudnnTensorDescriptor_t,
labels: *const i32,
label_lengths: *const i32,
input_lengths: *const i32,
algo: cudnnCTCLossAlgo_t,
ctc_desc: cudnnCTCLossDescriptor_t,
size_in_bytes: *mut usize,
) -> i32;
/// `cudnnCTCLoss(handle, probs_desc, probs, labels, label_lengths,
/// input_lengths, costs, grads_desc, grads, algo, ctc_desc,
/// workspace, workspace_bytes)` — fused CTC forward + backward.
/// Returns [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// `probs` is a device pointer to a `[T, B, C]` log-prob (or raw
/// prob; see `norm_mode`) tensor; `costs` is a device pointer to a
/// `[B]` per-sample loss buffer; `grads` is a device pointer to a
/// `[T, B, C]` gradient buffer (or null for FW-only). `labels`,
/// `label_lengths`, `input_lengths` are **host-side** `i32`
/// arrays as for [`cudnnGetCTCLossWorkspaceSize`].
///
/// Note the mixed residency: the three label / length arrays live
/// on the host while `probs`, `costs`, `grads`, and `workspace`
/// live on the device.
///
/// # Safety
/// All descriptors must be live; host / device pointers must be
/// valid for the stated shapes; `workspace` must point to at
/// least `workspace_bytes` of device memory (the value returned
/// by [`cudnnGetCTCLossWorkspaceSize`]). Stream attachment is
/// implicit via [`cudnnSetStream`].
pub fn cudnnCTCLoss(
handle: cudnnHandle_t,
probs_desc: cudnnTensorDescriptor_t,
probs: *const c_void,
labels: *const i32,
label_lengths: *const i32,
input_lengths: *const i32,
costs: *mut c_void,
grads_desc: cudnnTensorDescriptor_t,
grads: *mut c_void,
algo: cudnnCTCLossAlgo_t,
ctc_desc: cudnnCTCLossDescriptor_t,
workspace: *mut c_void,
workspace_bytes: usize,
) -> i32;
// ----- pooling descriptor lifecycle (Milestone 7.2) ----------------
/// `cudnnCreatePoolingDescriptor(desc)` — allocate a pooling
/// descriptor and store it through `desc`. Returns
/// [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// Configure the result with [`cudnnSetPooling2dDescriptor`] or
/// [`cudnnSetPoolingNdDescriptor`]; release it with
/// [`cudnnDestroyPoolingDescriptor`].
///
/// # Safety
/// `desc` must point to writable storage for one
/// [`cudnnPoolingDescriptor_t`].
pub fn cudnnCreatePoolingDescriptor(desc: *mut cudnnPoolingDescriptor_t) -> i32;
/// `cudnnDestroyPoolingDescriptor(desc)` — release a pooling
/// descriptor. Returns [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// # Safety
/// `desc` must be a valid descriptor returned by
/// [`cudnnCreatePoolingDescriptor`] that has not been previously
/// destroyed.
pub fn cudnnDestroyPoolingDescriptor(desc: cudnnPoolingDescriptor_t) -> i32;
/// `cudnnSetPooling2dDescriptor(desc, mode, nan_propagation,
/// window_h, window_w, pad_h, pad_w, stride_h, stride_w)` —
/// configure a 2-D pooling descriptor. Returns
/// [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// `mode` is a [`cudnnPoolingMode_t`]; `nan_propagation` is a
/// [`cudnnNanPropagation_t`] (typically [`CUDNN_NOT_PROPAGATE_NAN`],
/// matching PyTorch). The matching output spatial extents are
/// `out = floor((in + 2·pad - window) / stride) + 1`; the caller
/// encodes them in the `y` descriptor handed to
/// [`cudnnPoolingForward`] — this call does not record an output
/// shape.
///
/// # Safety
/// `desc` must be a live pooling descriptor.
pub fn cudnnSetPooling2dDescriptor(
desc: cudnnPoolingDescriptor_t,
mode: cudnnPoolingMode_t,
nan_propagation: cudnnNanPropagation_t,
window_h: i32,
window_w: i32,
pad_h: i32,
pad_w: i32,
stride_h: i32,
stride_w: i32,
) -> i32;
/// `cudnnSetPoolingNdDescriptor(desc, mode, nan_propagation,
/// nb_dims, window_a, padding_a, stride_a)` — configure an
/// N-dimensional pooling descriptor. `nb_dims` is the count of
/// *spatial* axes (1, 2, or 3 — i.e. excluding the leading `N` and
/// `C` axes of the data tensor). `window_a`, `padding_a`, and
/// `stride_a` each point to `nb_dims` readable `i32` values
/// describing the window extent, zero-padding, and stride along each
/// spatial axis in order. Returns [`CUDNN_STATUS_SUCCESS`] (`0`) on
/// success.
///
/// `mode` is a [`cudnnPoolingMode_t`] and `nan_propagation` a
/// [`cudnnNanPropagation_t`], exactly as for
/// [`cudnnSetPooling2dDescriptor`].
///
/// Used by 1-D / 3-D pooling plans (and by the cuDNN approximation
/// of adaptive pool) — they bind a matching rank-3 / rank-5 tensor
/// descriptor via [`cudnnSetTensorNdDescriptor`] and then drive
/// [`cudnnPoolingForward`] / [`cudnnPoolingBackward`] with the
/// resulting Nd pooling descriptor. The exec entry points are
/// rank-agnostic — they don't take the rank as an argument and
/// instead read it off the descriptor.
///
/// # Safety
/// `desc` must be a live pooling descriptor. `window_a`,
/// `padding_a`, and `stride_a` must each be readable for at least
/// `nb_dims` `i32` values.
pub fn cudnnSetPoolingNdDescriptor(
desc: cudnnPoolingDescriptor_t,
mode: cudnnPoolingMode_t,
nan_propagation: cudnnNanPropagation_t,
nb_dims: i32,
window_a: *const i32,
padding_a: *const i32,
stride_a: *const i32,
) -> i32;
// ----- pooling exec ------------------------------------------------
/// `cudnnPoolingForward(handle, pool_desc, alpha, x_desc, x, beta,
/// y_desc, y)` — FW pool. Computes
/// `y := alpha · pool(x) + beta · y`. Returns
/// [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// `alpha` / `beta` point to a single host-side scalar of the
/// accumulator dtype (`f32` for `f32` / `f16` / `bf16` operands;
/// `f64` for `f64` operands). For pure-store semantics pass
/// `*alpha = 1, *beta = 0`. **No workspace argument** — cuDNN's
/// pooling kernel allocates its small internal scratch itself.
///
/// Keep both `x` and `y` alive if a [`cudnnPoolingBackward`] call
/// will follow — the BW prototype takes both.
///
/// # Safety
/// All descriptors must be live and consistent with the operand
/// buffers. Stream is implicit via [`cudnnSetStream`].
pub fn cudnnPoolingForward(
handle: cudnnHandle_t,
pool_desc: cudnnPoolingDescriptor_t,
alpha: *const c_void,
x_desc: cudnnTensorDescriptor_t,
x: *const c_void,
beta: *const c_void,
y_desc: cudnnTensorDescriptor_t,
y: *mut c_void,
) -> i32;
/// `cudnnPoolingBackward(handle, pool_desc, alpha, y_desc, y,
/// dy_desc, dy, x_desc, x, beta, dx_desc, dx)` — BW pool. Computes
/// `dx := alpha · pool_grad(y, dy, x) + beta · dx`. Returns
/// [`CUDNN_STATUS_SUCCESS`] (`0`) on success.
///
/// **Both `y` (saved FW output) and `x` (saved FW input) must be
/// retained from the FW launch** — cuDNN uses `y` and `x`
/// together to recover the per-window argmax for max-pool (no
/// separate "indices" tensor is materialized by the legacy API).
/// For average-pool the gradient depends only on `x` (window
/// shape + denominator), but the prototype still requires `y` for
/// API uniformity.
///
/// `alpha` / `beta` follow the same host-scalar convention as
/// [`cudnnPoolingForward`]; `dx` is the only buffer written.
///
/// # Safety
/// As for [`cudnnPoolingForward`]. All four data buffers must be
/// live and shape-consistent with their descriptors.
pub fn cudnnPoolingBackward(
handle: cudnnHandle_t,
pool_desc: cudnnPoolingDescriptor_t,
alpha: *const c_void,
y_desc: cudnnTensorDescriptor_t,
y: *const c_void,
dy_desc: cudnnTensorDescriptor_t,
dy: *const c_void,
x_desc: cudnnTensorDescriptor_t,
x: *const c_void,
beta: *const c_void,
dx_desc: cudnnTensorDescriptor_t,
dx: *mut c_void,
) -> i32;
}
} // end mod cudnn_ffi
#[cfg(feature = "cudnn")]
pub use cudnn_ffi::*;
// Phase 19.1 — cuDNN pool FFI facade (non-adaptive MaxPool / AvgPool
// 1D / 2D / 3D, FW + BW × 4 fp dtypes). Pure-Rust `#[no_mangle]`
// wrappers exposing the cuDNN-backed pool plans as flat C symbols for
// non-Rust callers (Fuel). See module docs for the handle-lifecycle
// + indices contract.
#[cfg(feature = "cudnn")]
mod pool_cudnn_facade;
#[cfg(feature = "cudnn")]
pub use pool_cudnn_facade::*;
// ============================================================================
// Phase 7 Milestone 7.3 — Indexing / scatter / gather (Category L)
// ============================================================================
//
// Six ops: gather + gather_backward, scatter_add, index_select +
// index_select_backward, masked_fill + masked_fill_backward, one_hot,
// nonzero. Index dtype is i32 only (i64 deferred). All FFI signatures
// match the INSTANTIATE macros in `kernels/include/baracuda_indexing.cuh`.
//
// Trailblazer dtype coverage:
// - gather, scatter_add, index_select (FW): f32, f64, i32
// - gather BW, index_select BW: f32, f64 (FP-only — uses atomicAdd)
// - masked_fill FW + BW: f32, f64, i32, bool
// - one_hot FW: input always i32, output f32 / f64 / i32 / bool
// - nonzero FW: input f32 / f64 / i32 / bool, output i32 indices
//
// Out-of-bounds index policy: kernels skip (no write); negative indices
// are treated as out-of-bounds (no wrap-around — PyTorch-style wrap is
// deferred).
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---------- gather (FW) ----------
/// `out[..., j, ...] = src[..., index[..., j, ...], ...]` along
/// `gather_dim`. f32.
///
/// `index` holds `i32` indices. Out-of-bounds indices (negative
/// values included — there is no wrap-around) are skipped: the
/// kernel writes nothing for them. Returns a crate status code
/// (`0` = success; see the crate-level "Status codes" list).
///
/// # Safety
/// `src`, `index`, and `out` must be valid device addresses.
/// `workspace`, when non-null, must point to at least
/// `workspace_bytes` of writable device memory. `stream` must be a
/// valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
/// NOTE(review): `out_shape` and the `stride_*` arrays presumably
/// hold `rank` entries each — confirm their length, units, and
/// residency against `kernels/include/baracuda_indexing.cuh`.
pub fn baracuda_kernels_gather_f32_run(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_index: *const i64,
stride_out: *const i64,
src: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_f32`.
///
/// Takes the same problem description as
/// [`baracuda_kernels_gather_f32_run`] minus the workspace and
/// stream arguments, and returns a crate status code (`0` = the
/// kernel accepts this problem; see the crate-level "Status codes"
/// list).
///
/// # Safety
/// NOTE(review): whether the check dereferences `out_shape`, the
/// stride arrays, or the data pointers is not visible here — pass
/// the same valid pointers you would hand to the `_run` call.
pub fn baracuda_kernels_gather_f32_can_implement(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_index: *const i64,
stride_out: *const i64,
src: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
/// `gather` along `gather_dim`. f64.
///
/// Same contract as [`baracuda_kernels_gather_f32_run`] with `f64`
/// data: `index` holds `i32` indices, and out-of-bounds indices
/// (negative values included) are skipped without a write. Returns
/// a crate status code (`0` = success).
///
/// # Safety
/// `src`, `index`, and `out` must be valid device addresses.
/// `workspace`, when non-null, must point to at least
/// `workspace_bytes` of writable device memory. `stream` must be a
/// valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
/// NOTE(review): confirm the length and residency of `out_shape`
/// and the `stride_*` arrays against
/// `kernels/include/baracuda_indexing.cuh`.
pub fn baracuda_kernels_gather_f64_run(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_index: *const i64,
stride_out: *const i64,
src: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_f64`.
///
/// Takes the same problem description as
/// [`baracuda_kernels_gather_f64_run`] minus the workspace and
/// stream arguments, and returns a crate status code (`0` = the
/// kernel accepts this problem).
///
/// # Safety
/// NOTE(review): whether the check dereferences its pointer
/// arguments is not visible here — pass the same valid pointers you
/// would hand to the `_run` call.
pub fn baracuda_kernels_gather_f64_can_implement(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_index: *const i64,
stride_out: *const i64,
src: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
/// `gather` along `gather_dim`. i32.
///
/// Same contract as [`baracuda_kernels_gather_f32_run`] with `i32`
/// data: `index` holds `i32` indices, and out-of-bounds indices
/// (negative values included) are skipped without a write. Returns
/// a crate status code (`0` = success).
///
/// # Safety
/// `src`, `index`, and `out` must be valid device addresses.
/// `workspace`, when non-null, must point to at least
/// `workspace_bytes` of writable device memory. `stream` must be a
/// valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
/// NOTE(review): confirm the length and residency of `out_shape`
/// and the `stride_*` arrays against
/// `kernels/include/baracuda_indexing.cuh`.
pub fn baracuda_kernels_gather_i32_run(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_index: *const i64,
stride_out: *const i64,
src: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_i32`.
///
/// Takes the same problem description as
/// [`baracuda_kernels_gather_i32_run`] minus the workspace and
/// stream arguments, and returns a crate status code (`0` = the
/// kernel accepts this problem).
///
/// # Safety
/// NOTE(review): whether the check dereferences its pointer
/// arguments is not visible here — pass the same valid pointers you
/// would hand to the `_run` call.
pub fn baracuda_kernels_gather_i32_can_implement(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_index: *const i64,
stride_out: *const i64,
src: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
// ---------- gather_backward (scatter-add into dsrc) ----------
/// `dsrc[..., index[..., j, ...], ...] += dout[..., j, ...]` along
/// `gather_dim`. f32 (atomicAdd).
///
/// The update accumulates (`+=`) into `dsrc`, so the buffer's prior
/// contents are part of the result. `index` holds `i32` indices;
/// out-of-bounds indices (negative values included) are skipped
/// without a write. Returns a crate status code (`0` = success; see
/// the crate-level "Status codes" list).
///
/// # Safety
/// `dout`, `index`, and `dsrc` must be valid device addresses.
/// `workspace`, when non-null, must point to at least
/// `workspace_bytes` of writable device memory. `stream` must be a
/// valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
/// NOTE(review): confirm the length and residency of `out_shape`
/// and the `stride_*` arrays against
/// `kernels/include/baracuda_indexing.cuh`.
pub fn baracuda_kernels_gather_backward_f32_run(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_dout: *const i64,
stride_index: *const i64,
stride_dsrc: *const i64,
dout: *const c_void,
index: *const c_void,
dsrc: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_backward_f32`.
///
/// Takes the same problem description as
/// [`baracuda_kernels_gather_backward_f32_run`] minus the workspace
/// and stream arguments, and returns a crate status code (`0` = the
/// kernel accepts this problem).
///
/// # Safety
/// NOTE(review): whether the check dereferences its pointer
/// arguments is not visible here — pass the same valid pointers you
/// would hand to the `_run` call.
pub fn baracuda_kernels_gather_backward_f32_can_implement(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_dout: *const i64,
stride_index: *const i64,
stride_dsrc: *const i64,
dout: *const c_void,
index: *const c_void,
dsrc: *const c_void,
) -> i32;
/// `gather_backward` — f64 (atomicAdd).
///
/// Same contract as [`baracuda_kernels_gather_backward_f32_run`]
/// with `f64` data: accumulates into `dsrc`, takes `i32` indices,
/// and skips out-of-bounds indices (negative values included)
/// without a write. Returns a crate status code (`0` = success).
///
/// # Safety
/// `dout`, `index`, and `dsrc` must be valid device addresses.
/// `workspace`, when non-null, must point to at least
/// `workspace_bytes` of writable device memory. `stream` must be a
/// valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
/// NOTE(review): confirm the length and residency of `out_shape`
/// and the `stride_*` arrays against
/// `kernels/include/baracuda_indexing.cuh`.
pub fn baracuda_kernels_gather_backward_f64_run(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_dout: *const i64,
stride_index: *const i64,
stride_dsrc: *const i64,
dout: *const c_void,
index: *const c_void,
dsrc: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_backward_f64`.
///
/// Takes the same problem description as
/// [`baracuda_kernels_gather_backward_f64_run`] minus the workspace
/// and stream arguments, and returns a crate status code (`0` = the
/// kernel accepts this problem).
///
/// # Safety
/// NOTE(review): whether the check dereferences its pointer
/// arguments is not visible here — pass the same valid pointers you
/// would hand to the `_run` call.
pub fn baracuda_kernels_gather_backward_f64_can_implement(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_dout: *const i64,
stride_index: *const i64,
stride_dsrc: *const i64,
dout: *const c_void,
index: *const c_void,
dsrc: *const c_void,
) -> i32;
// ---------- scatter_add ----------
/// `out[..., index[..., j, ...], ...] += updates[..., j, ...]`
/// along `scatter_dim`. f32 (atomicAdd).
///
/// The update accumulates (`+=`) into `out`, so the buffer's prior
/// contents are part of the result. `index` holds `i32` indices;
/// out-of-bounds indices (negative values included) are skipped
/// without a write. Returns a crate status code (`0` = success; see
/// the crate-level "Status codes" list).
///
/// # Safety
/// `updates`, `index`, and `out` must be valid device addresses.
/// `workspace`, when non-null, must point to at least
/// `workspace_bytes` of writable device memory. `stream` must be a
/// valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
/// NOTE(review): confirm the length and residency of `upd_shape`
/// and the `stride_*` arrays against
/// `kernels/include/baracuda_indexing.cuh`.
pub fn baracuda_kernels_scatter_add_f32_run(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_add_f32`.
///
/// Takes the same problem description as
/// [`baracuda_kernels_scatter_add_f32_run`] minus the workspace and
/// stream arguments, and returns a crate status code (`0` = the
/// kernel accepts this problem).
///
/// # Safety
/// NOTE(review): whether the check dereferences its pointer
/// arguments is not visible here — pass the same valid pointers you
/// would hand to the `_run` call.
pub fn baracuda_kernels_scatter_add_f32_can_implement(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
/// `scatter_add` — f64 (atomicAdd).
///
/// Same contract as [`baracuda_kernels_scatter_add_f32_run`] with
/// `f64` data: accumulates into `out`, takes `i32` indices, and
/// skips out-of-bounds indices (negative values included) without a
/// write. Returns a crate status code (`0` = success).
///
/// # Safety
/// `updates`, `index`, and `out` must be valid device addresses.
/// `workspace`, when non-null, must point to at least
/// `workspace_bytes` of writable device memory. `stream` must be a
/// valid CUDA stream (`cudaStream_t` cast to `*mut c_void`).
/// NOTE(review): confirm the length and residency of `upd_shape`
/// and the `stride_*` arrays against
/// `kernels/include/baracuda_indexing.cuh`.
pub fn baracuda_kernels_scatter_add_f64_run(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_add_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_scatter_add_f64_can_implement(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
// ---------- index_select (FW) ----------
/// `out[..., j, ...] = src[..., idx[j], ...]` along `select_dim`.
/// `idx` is 1-D i32. f32.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
/// NOTE(review): `out_shape` and the stride arrays presumably hold
/// `rank` entries each — confirm against the launcher.
pub fn baracuda_kernels_index_select_f32_run(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_out: *const i64,
src: *const c_void,
idx: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_index_select_f32_can_implement(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_out: *const i64,
src: *const c_void,
idx: *const c_void,
out: *const c_void,
) -> i32;
/// `index_select` — f64.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_index_select_f64_run(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_out: *const i64,
src: *const c_void,
idx: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_index_select_f64_can_implement(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_out: *const i64,
src: *const c_void,
idx: *const c_void,
out: *const c_void,
) -> i32;
/// `index_select` — i32.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_index_select_i32_run(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_out: *const i64,
src: *const c_void,
idx: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_i32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_index_select_i32_can_implement(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_out: *const i64,
src: *const c_void,
idx: *const c_void,
out: *const c_void,
) -> i32;
// ---------- index_select_backward (scatter-add into dsrc) ----------
/// `dsrc[..., idx[j], ...] += dout[..., j, ...]` along
/// `select_dim`. f32 (atomicAdd).
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
/// NOTE(review): the update is `+=`, so `dsrc` presumably has to be
/// initialised by the caller before launch — confirm.
pub fn baracuda_kernels_index_select_backward_f32_run(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_dout: *const i64,
stride_dsrc: *const i64,
dout: *const c_void,
idx: *const c_void,
dsrc: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_backward_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_index_select_backward_f32_can_implement(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_dout: *const i64,
stride_dsrc: *const i64,
dout: *const c_void,
idx: *const c_void,
dsrc: *const c_void,
) -> i32;
/// `index_select_backward` — f64 (atomicAdd).
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_index_select_backward_f64_run(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_dout: *const i64,
stride_dsrc: *const i64,
dout: *const c_void,
idx: *const c_void,
dsrc: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_backward_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_index_select_backward_f64_can_implement(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_dout: *const i64,
stride_dsrc: *const i64,
dout: *const c_void,
idx: *const c_void,
dsrc: *const c_void,
) -> i32;
// ---------- masked_fill (FW) ----------
//
// `fill_bits` is a 64-bit payload that the kernel reinterprets into
// T at launch (`__builtin_memcpy` of `sizeof(T)` bytes). Caller
// encodes their fill value into the low bits.
/// `out[i] = mask[i] ? fill_value : src[i]`. f32 (caller passes
/// `fill_value.to_bits() as i64`).
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
/// NOTE(review): the element type of `mask` is not visible here —
/// confirm against the kernel source.
pub fn baracuda_kernels_masked_fill_f32_run(
numel: i64,
src: *const c_void,
mask: *const c_void,
out: *mut c_void,
fill_bits: i64,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `masked_fill_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_masked_fill_f32_can_implement(
numel: i64,
src: *const c_void,
mask: *const c_void,
out: *const c_void,
fill_bits: i64,
) -> i32;
/// `masked_fill` — f64.
///
/// `fill_bits` carries the fill value's bit pattern in its low bits;
/// for f64 that is presumably `fill_value.to_bits() as i64` — confirm
/// against the safe wrapper.
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_masked_fill_f64_run(
numel: i64,
src: *const c_void,
mask: *const c_void,
out: *mut c_void,
fill_bits: i64,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `masked_fill_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_masked_fill_f64_can_implement(
numel: i64,
src: *const c_void,
mask: *const c_void,
out: *const c_void,
fill_bits: i64,
) -> i32;
/// `masked_fill` — i32.
///
/// `fill_bits` carries the fill value in its low bits; for i32 that
/// is presumably the value widened to `i64` — confirm against the
/// safe wrapper.
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_masked_fill_i32_run(
numel: i64,
src: *const c_void,
mask: *const c_void,
out: *mut c_void,
fill_bits: i64,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `masked_fill_i32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_masked_fill_i32_can_implement(
numel: i64,
src: *const c_void,
mask: *const c_void,
out: *const c_void,
fill_bits: i64,
) -> i32;
/// `masked_fill` — bool (u8 storage).
///
/// `fill_bits` carries the fill value in its low bits; for bool that
/// is presumably `0` or `1` in the low byte — confirm against the
/// safe wrapper.
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_masked_fill_bool_run(
numel: i64,
src: *const c_void,
mask: *const c_void,
out: *mut c_void,
fill_bits: i64,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `masked_fill_bool`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_masked_fill_bool_can_implement(
numel: i64,
src: *const c_void,
mask: *const c_void,
out: *const c_void,
fill_bits: i64,
) -> i32;
// ---------- masked_fill_backward ----------
/// `dsrc[i] = mask[i] ? 0 : dout[i]`. f32.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_masked_fill_backward_f32_run(
numel: i64,
dout: *const c_void,
mask: *const c_void,
dsrc: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `masked_fill_backward_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_masked_fill_backward_f32_can_implement(
numel: i64,
dout: *const c_void,
mask: *const c_void,
dsrc: *const c_void,
) -> i32;
/// `masked_fill_backward` — f64.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_masked_fill_backward_f64_run(
numel: i64,
dout: *const c_void,
mask: *const c_void,
dsrc: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `masked_fill_backward_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_masked_fill_backward_f64_can_implement(
numel: i64,
dout: *const c_void,
mask: *const c_void,
dsrc: *const c_void,
) -> i32;
/// `masked_fill_backward` — i32.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_masked_fill_backward_i32_run(
numel: i64,
dout: *const c_void,
mask: *const c_void,
dsrc: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `masked_fill_backward_i32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_masked_fill_backward_i32_can_implement(
numel: i64,
dout: *const c_void,
mask: *const c_void,
dsrc: *const c_void,
) -> i32;
/// `masked_fill_backward` — bool (u8 storage).
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_masked_fill_backward_bool_run(
numel: i64,
dout: *const c_void,
mask: *const c_void,
dsrc: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `masked_fill_backward_bool`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_masked_fill_backward_bool_can_implement(
numel: i64,
dout: *const c_void,
mask: *const c_void,
dsrc: *const c_void,
) -> i32;
// ---------- one_hot (FW) ----------
/// `out[..., c] = 1 if c == src[...] else 0`. Output last axis has
/// extent `num_classes`. Input dtype is always i32; output is f32.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
/// NOTE(review): `out_numel` presumably counts output elements
/// (input elements × `num_classes`) — confirm.
pub fn baracuda_kernels_one_hot_f32_run(
out_numel: i64,
num_classes: i32,
src: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `one_hot_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_one_hot_f32_can_implement(
out_numel: i64,
num_classes: i32,
src: *const c_void,
out: *const c_void,
) -> i32;
/// `one_hot` — f64 output.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_one_hot_f64_run(
out_numel: i64,
num_classes: i32,
src: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `one_hot_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_one_hot_f64_can_implement(
out_numel: i64,
num_classes: i32,
src: *const c_void,
out: *const c_void,
) -> i32;
/// `one_hot` — i32 output.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_one_hot_i32_run(
out_numel: i64,
num_classes: i32,
src: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `one_hot_i32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_one_hot_i32_can_implement(
out_numel: i64,
num_classes: i32,
src: *const c_void,
out: *const c_void,
) -> i32;
/// `one_hot` — bool output (u8 storage).
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_one_hot_bool_run(
out_numel: i64,
num_classes: i32,
src: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `one_hot_bool`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_one_hot_bool_can_implement(
out_numel: i64,
num_classes: i32,
src: *const c_void,
out: *const c_void,
) -> i32;
// ---------- nonzero (FW) ----------
//
// Output is a `[max_nz, rank]` i32 coordinate table plus a single
// i32 counter (the launcher zeros it via `cudaMemsetAsync` before the
// kernel). The kernel uses a global atomic counter, so the coordinate
// rows are NOT emitted in row-major order of the input — callers that
// need sorted output should sort afterward.
/// Coordinates where `x[i] != 0`. f32 input.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
/// NOTE(review): behaviour when more than `max_nz` elements are
/// non-zero is not visible from this declaration — confirm.
pub fn baracuda_kernels_nonzero_f32_run(
numel: i64,
rank: i32,
max_nz: i32,
shape: *const i32,
stride_x: *const i64,
x: *const c_void,
out_coords: *mut c_void,
counter: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `nonzero_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_nonzero_f32_can_implement(
numel: i64,
rank: i32,
max_nz: i32,
shape: *const i32,
stride_x: *const i64,
x: *const c_void,
out_coords: *const c_void,
counter: *const c_void,
) -> i32;
/// `nonzero` — f64 input.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_nonzero_f64_run(
numel: i64,
rank: i32,
max_nz: i32,
shape: *const i32,
stride_x: *const i64,
x: *const c_void,
out_coords: *mut c_void,
counter: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `nonzero_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_nonzero_f64_can_implement(
numel: i64,
rank: i32,
max_nz: i32,
shape: *const i32,
stride_x: *const i64,
x: *const c_void,
out_coords: *const c_void,
counter: *const c_void,
) -> i32;
/// `nonzero` — i32 input.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_nonzero_i32_run(
numel: i64,
rank: i32,
max_nz: i32,
shape: *const i32,
stride_x: *const i64,
x: *const c_void,
out_coords: *mut c_void,
counter: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `nonzero_i32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_nonzero_i32_can_implement(
numel: i64,
rank: i32,
max_nz: i32,
shape: *const i32,
stride_x: *const i64,
x: *const c_void,
out_coords: *const c_void,
counter: *const c_void,
) -> i32;
/// `nonzero` — bool (u8) input.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_nonzero_bool_run(
numel: i64,
rank: i32,
max_nz: i32,
shape: *const i32,
stride_x: *const i64,
x: *const c_void,
out_coords: *mut c_void,
counter: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `nonzero_bool`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_nonzero_bool_can_implement(
numel: i64,
rank: i32,
max_nz: i32,
shape: *const i32,
stride_x: *const i64,
x: *const c_void,
out_coords: *const c_void,
counter: *const c_void,
) -> i32;
// ---------- i64-index variants (Phase 11.5 / Fuel team feedback #7) ----------
//
// PyTorch defaults the index tensor to int64 for gather / scatter /
// index_select / one_hot. Casting an i64 index tensor down to i32
// on the host before each launch is a measurable overhead the
// safe layer wanted gone. Each `_i64idx_` symbol is identical to
// the i32 sibling above except that the `index` / `idx` / `src`
// pointer is dereferenced as int64. The `nonzero` `_i64idx_`
// variants differ on the output side instead: they write int64
// coordinates.
/// `gather` FW — f32, i64 indices.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_gather_i64idx_f32_run(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_index: *const i64,
stride_out: *const i64,
src: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_i64idx_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_gather_i64idx_f32_can_implement(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_index: *const i64,
stride_out: *const i64,
src: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
/// `gather` FW — f64, i64 indices.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_gather_i64idx_f64_run(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_index: *const i64,
stride_out: *const i64,
src: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_i64idx_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_gather_i64idx_f64_can_implement(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_index: *const i64,
stride_out: *const i64,
src: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
/// `gather` FW — i32 values, i64 indices.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_gather_i64idx_i32_run(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_index: *const i64,
stride_out: *const i64,
src: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_i64idx_i32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_gather_i64idx_i32_can_implement(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_index: *const i64,
stride_out: *const i64,
src: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
/// `gather` BW — f32, i64 indices (atomicAdd).
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_gather_backward_i64idx_f32_run(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_dout: *const i64,
stride_index: *const i64,
stride_dsrc: *const i64,
dout: *const c_void,
index: *const c_void,
dsrc: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_backward_i64idx_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_gather_backward_i64idx_f32_can_implement(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_dout: *const i64,
stride_index: *const i64,
stride_dsrc: *const i64,
dout: *const c_void,
index: *const c_void,
dsrc: *const c_void,
) -> i32;
/// `gather` BW — f64, i64 indices (atomicAdd).
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_gather_backward_i64idx_f64_run(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_dout: *const i64,
stride_index: *const i64,
stride_dsrc: *const i64,
dout: *const c_void,
index: *const c_void,
dsrc: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_backward_i64idx_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_gather_backward_i64idx_f64_can_implement(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_dout: *const i64,
stride_index: *const i64,
stride_dsrc: *const i64,
dout: *const c_void,
index: *const c_void,
dsrc: *const c_void,
) -> i32;
/// `scatter_add` — f32, i64 indices (atomicAdd).
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_scatter_add_i64idx_f32_run(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_add_i64idx_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_scatter_add_i64idx_f32_can_implement(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
/// `scatter_add` — f64, i64 indices.
///
/// NOTE(review): the f32 sibling and the i32-index f64 variant are
/// both marked "(atomicAdd)"; this one presumably accumulates the
/// same way — confirm and align the summary line.
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_scatter_add_i64idx_f64_run(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_add_i64idx_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_scatter_add_i64idx_f64_can_implement(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
/// `index_select` — f32, i64 indices.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_index_select_i64idx_f32_run(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_out: *const i64,
src: *const c_void,
idx: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_i64idx_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_index_select_i64idx_f32_can_implement(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_out: *const i64,
src: *const c_void,
idx: *const c_void,
out: *const c_void,
) -> i32;
/// `index_select` — f64, i64 indices.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_index_select_i64idx_f64_run(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_out: *const i64,
src: *const c_void,
idx: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_i64idx_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_index_select_i64idx_f64_can_implement(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_out: *const i64,
src: *const c_void,
idx: *const c_void,
out: *const c_void,
) -> i32;
/// `index_select` — i32 values, i64 indices.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_index_select_i64idx_i32_run(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_out: *const i64,
src: *const c_void,
idx: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_i64idx_i32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_index_select_i64idx_i32_can_implement(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_out: *const i64,
src: *const c_void,
idx: *const c_void,
out: *const c_void,
) -> i32;
/// `index_select` BW — f32, i64 indices.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_index_select_backward_i64idx_f32_run(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_dout: *const i64,
stride_dsrc: *const i64,
dout: *const c_void,
idx: *const c_void,
dsrc: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_backward_i64idx_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_index_select_backward_i64idx_f32_can_implement(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_dout: *const i64,
stride_dsrc: *const i64,
dout: *const c_void,
idx: *const c_void,
dsrc: *const c_void,
) -> i32;
/// `index_select` BW — f64, i64 indices.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_index_select_backward_i64idx_f64_run(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_dout: *const i64,
stride_dsrc: *const i64,
dout: *const c_void,
idx: *const c_void,
dsrc: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_backward_i64idx_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_index_select_backward_i64idx_f64_can_implement(
out_numel: i64,
rank: i32,
select_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_dout: *const i64,
stride_dsrc: *const i64,
dout: *const c_void,
idx: *const c_void,
dsrc: *const c_void,
) -> i32;
/// `one_hot` — f32 output, i64 input class indices.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_one_hot_i64idx_f32_run(
out_numel: i64,
num_classes: i32,
src: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `one_hot_i64idx_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_one_hot_i64idx_f32_can_implement(
out_numel: i64,
num_classes: i32,
src: *const c_void,
out: *const c_void,
) -> i32;
/// `one_hot` — f64 output, i64 indices.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_one_hot_i64idx_f64_run(
out_numel: i64,
num_classes: i32,
src: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `one_hot_i64idx_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_one_hot_i64idx_f64_can_implement(
out_numel: i64,
num_classes: i32,
src: *const c_void,
out: *const c_void,
) -> i32;
/// `one_hot` — i32 output, i64 indices.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_one_hot_i64idx_i32_run(
out_numel: i64,
num_classes: i32,
src: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `one_hot_i64idx_i32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_one_hot_i64idx_i32_can_implement(
out_numel: i64,
num_classes: i32,
src: *const c_void,
out: *const c_void,
) -> i32;
/// `one_hot` — bool output, i64 indices.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_one_hot_i64idx_bool_run(
out_numel: i64,
num_classes: i32,
src: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `one_hot_i64idx_bool`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_one_hot_i64idx_bool_can_implement(
out_numel: i64,
num_classes: i32,
src: *const c_void,
out: *const c_void,
) -> i32;
/// `nonzero` — f32 input, i64 output coords.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
/// NOTE(review): the element width of `counter` for this variant is
/// not visible here (the i32-coordinate sibling uses an i32 counter)
/// — confirm.
pub fn baracuda_kernels_nonzero_i64idx_f32_run(
numel: i64,
rank: i32,
max_nz: i32,
shape: *const i32,
stride_x: *const i64,
x: *const c_void,
out_coords: *mut c_void,
counter: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `nonzero_i64idx_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_nonzero_i64idx_f32_can_implement(
numel: i64,
rank: i32,
max_nz: i32,
shape: *const i32,
stride_x: *const i64,
x: *const c_void,
out_coords: *const c_void,
counter: *const c_void,
) -> i32;
/// `nonzero` — f64 input, i64 output coords.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_nonzero_i64idx_f64_run(
numel: i64,
rank: i32,
max_nz: i32,
shape: *const i32,
stride_x: *const i64,
x: *const c_void,
out_coords: *mut c_void,
counter: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `nonzero_i64idx_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_nonzero_i64idx_f64_can_implement(
numel: i64,
rank: i32,
max_nz: i32,
shape: *const i32,
stride_x: *const i64,
x: *const c_void,
out_coords: *const c_void,
counter: *const c_void,
) -> i32;
/// `nonzero` — i32 input, i64 output coords.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_nonzero_i64idx_i32_run(
numel: i64,
rank: i32,
max_nz: i32,
shape: *const i32,
stride_x: *const i64,
x: *const c_void,
out_coords: *mut c_void,
counter: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `nonzero_i64idx_i32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_nonzero_i64idx_i32_can_implement(
numel: i64,
rank: i32,
max_nz: i32,
shape: *const i32,
stride_x: *const i64,
x: *const c_void,
out_coords: *const c_void,
counter: *const c_void,
) -> i32;
/// `nonzero` — bool input, i64 output coords.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_nonzero_i64idx_bool_run(
numel: i64,
rank: i32,
max_nz: i32,
shape: *const i32,
stride_x: *const i64,
x: *const c_void,
out_coords: *mut c_void,
counter: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `nonzero_i64idx_bool`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_nonzero_i64idx_bool_can_implement(
numel: i64,
rank: i32,
max_nz: i32,
shape: *const i32,
stride_x: *const i64,
x: *const c_void,
out_coords: *const c_void,
counter: *const c_void,
) -> i32;
// ---------- Phase 39 (Fuel 6c.4 Gap 5) — scatter (pure assign) + index_add ----------
//
// Two new ops broaden the indexing matrix:
// * `scatter` — `out[..., index[..., j, ...], ...] = updates[..., j, ...]`
// (NO accumulation; last writer wins on duplicate-target races).
// Distinct from `scatter_add` above. 4 FP dtypes × {i32, i64 idx}.
// * `index_add` — `dst[idx[i], ...] += src[i, ...]` along `add_dim`
// (atomicAdd-Σ accumulation; algorithmically identical to
// `index_select_backward` but exposed under a non-autograd-flavored
// name + with f16 / bf16 dtype fanout).
//
// Out-of-bounds and negative index policy matches the rest of the
// Indexing family: skip the write (no PyTorch-style wrap).
// -- scatter (pure assign) FW --
/// `scatter` — `out[index] = updates`, f32, i32 idx. NO accumulation.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
/// NOTE(review): `upd_shape` and the stride arrays presumably hold
/// `rank` entries each — confirm against the launcher.
pub fn baracuda_kernels_scatter_f32_run(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_scatter_f32_can_implement(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
/// `scatter` — f64, i32 idx.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_scatter_f64_run(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_scatter_f64_can_implement(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
/// `scatter` — f16, i32 idx.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_scatter_f16_run(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_f16`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_scatter_f16_can_implement(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
/// `scatter` — bf16, i32 idx.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_scatter_bf16_run(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_bf16`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_scatter_bf16_can_implement(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
/// `scatter` — f32, i64 idx.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_scatter_i64idx_f32_run(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_i64idx_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_f32_can_implement(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
/// `scatter` — f64, i64 idx.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_scatter_i64idx_f64_run(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_i64idx_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_f64_can_implement(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
/// `scatter` — f16, i64 idx.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_scatter_i64idx_f16_run(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_i64idx_f16`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_f16_can_implement(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
/// `scatter` — bf16, i64 idx.
///
/// Returns a crate status code (`0` = success). `workspace`, when
/// non-null, must cover `workspace_bytes` of writable device memory;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
pub fn baracuda_kernels_scatter_i64idx_bf16_run(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_i64idx_bf16`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes`
/// and `stream`; returns a crate status code (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_bf16_can_implement(
upd_numel: i64,
rank: i32,
scatter_dim: i32,
out_dim_size: i32,
upd_shape: *const i32,
stride_upd: *const i64,
stride_index: *const i64,
stride_out: *const i64,
updates: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
// -- index_add FW --
/// `index_add` — `dst[idx[i], ...] += src[i, ...]`, f32, i32 idx.
///
/// `src` and `idx` are read-only inputs; `dst` is accumulated into in
/// place. `workspace` may be null unless the kernel needs scratch space;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
/// NOTE(review): shape/stride arrays presumably hold `rank` entries — confirm.
///
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_f32_run(
src_numel: i64,
rank: i32,
add_dim: i32,
dst_dim_size: i32,
src_shape: *const i32,
stride_src: *const i64,
stride_dst: *const i64,
src: *const c_void,
idx: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_add_f32`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_add_f32_run`, minus workspace and stream;
/// `dst` is `*const` here. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_f32_can_implement(
src_numel: i64,
rank: i32,
add_dim: i32,
dst_dim_size: i32,
src_shape: *const i32,
stride_src: *const i64,
stride_dst: *const i64,
src: *const c_void,
idx: *const c_void,
dst: *const c_void,
) -> i32;
/// `index_add` — f64, i32 idx.
///
/// `src` and `idx` are read-only inputs; `dst` is accumulated into in
/// place. `workspace` may be null unless the kernel needs scratch space;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
/// NOTE(review): shape/stride arrays presumably hold `rank` entries — confirm.
///
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_f64_run(
src_numel: i64,
rank: i32,
add_dim: i32,
dst_dim_size: i32,
src_shape: *const i32,
stride_src: *const i64,
stride_dst: *const i64,
src: *const c_void,
idx: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_add_f64`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_add_f64_run`, minus workspace and stream;
/// `dst` is `*const` here. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_f64_can_implement(
src_numel: i64,
rank: i32,
add_dim: i32,
dst_dim_size: i32,
src_shape: *const i32,
stride_src: *const i64,
stride_dst: *const i64,
src: *const c_void,
idx: *const c_void,
dst: *const c_void,
) -> i32;
/// `index_add` — f16, i32 idx. Uses `atomicCAS`-via-
/// `baracuda::atomic::add<__half>` (deterministic per-thread arithmetic
/// regardless of CUDA toolkit; non-deterministic accumulation order).
///
/// `src` and `idx` are read-only inputs; `dst` is accumulated into in
/// place. `workspace` may be null unless the kernel needs scratch space;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
/// NOTE(review): shape/stride arrays presumably hold `rank` entries — confirm.
///
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_f16_run(
src_numel: i64,
rank: i32,
add_dim: i32,
dst_dim_size: i32,
src_shape: *const i32,
stride_src: *const i64,
stride_dst: *const i64,
src: *const c_void,
idx: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_add_f16`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_add_f16_run`, minus workspace and stream;
/// `dst` is `*const` here. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_f16_can_implement(
src_numel: i64,
rank: i32,
add_dim: i32,
dst_dim_size: i32,
src_shape: *const i32,
stride_src: *const i64,
stride_dst: *const i64,
src: *const c_void,
idx: *const c_void,
dst: *const c_void,
) -> i32;
/// `index_add` — bf16, i32 idx. `atomicCAS`-via-
/// `baracuda::atomic::add<__nv_bfloat16>` (same caveats as f16).
///
/// `src` and `idx` are read-only inputs; `dst` is accumulated into in
/// place. `workspace` may be null unless the kernel needs scratch space;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
/// NOTE(review): shape/stride arrays presumably hold `rank` entries — confirm.
///
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_bf16_run(
src_numel: i64,
rank: i32,
add_dim: i32,
dst_dim_size: i32,
src_shape: *const i32,
stride_src: *const i64,
stride_dst: *const i64,
src: *const c_void,
idx: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_add_bf16`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_add_bf16_run`, minus workspace and stream;
/// `dst` is `*const` here. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_bf16_can_implement(
src_numel: i64,
rank: i32,
add_dim: i32,
dst_dim_size: i32,
src_shape: *const i32,
stride_src: *const i64,
stride_dst: *const i64,
src: *const c_void,
idx: *const c_void,
dst: *const c_void,
) -> i32;
/// `index_add` — f32, i64 idx.
///
/// `src` and `idx` are read-only inputs; `dst` is accumulated into in
/// place. `workspace` may be null unless the kernel needs scratch space;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
/// NOTE(review): shape/stride arrays presumably hold `rank` entries — confirm.
///
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_i64idx_f32_run(
src_numel: i64,
rank: i32,
add_dim: i32,
dst_dim_size: i32,
src_shape: *const i32,
stride_src: *const i64,
stride_dst: *const i64,
src: *const c_void,
idx: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_add_i64idx_f32`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_add_i64idx_f32_run`, minus workspace and stream;
/// `dst` is `*const` here. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_i64idx_f32_can_implement(
src_numel: i64,
rank: i32,
add_dim: i32,
dst_dim_size: i32,
src_shape: *const i32,
stride_src: *const i64,
stride_dst: *const i64,
src: *const c_void,
idx: *const c_void,
dst: *const c_void,
) -> i32;
/// `index_add` — f64, i64 idx.
///
/// `src` and `idx` are read-only inputs; `dst` is accumulated into in
/// place. `workspace` may be null unless the kernel needs scratch space;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
/// NOTE(review): shape/stride arrays presumably hold `rank` entries — confirm.
///
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_i64idx_f64_run(
src_numel: i64,
rank: i32,
add_dim: i32,
dst_dim_size: i32,
src_shape: *const i32,
stride_src: *const i64,
stride_dst: *const i64,
src: *const c_void,
idx: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_add_i64idx_f64`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_add_i64idx_f64_run`, minus workspace and stream;
/// `dst` is `*const` here. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_i64idx_f64_can_implement(
src_numel: i64,
rank: i32,
add_dim: i32,
dst_dim_size: i32,
src_shape: *const i32,
stride_src: *const i64,
stride_dst: *const i64,
src: *const c_void,
idx: *const c_void,
dst: *const c_void,
) -> i32;
/// `index_add` — f16, i64 idx.
///
/// `src` and `idx` are read-only inputs; `dst` is accumulated into in
/// place. `workspace` may be null unless the kernel needs scratch space;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
/// NOTE(review): shape/stride arrays presumably hold `rank` entries — confirm.
///
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_i64idx_f16_run(
src_numel: i64,
rank: i32,
add_dim: i32,
dst_dim_size: i32,
src_shape: *const i32,
stride_src: *const i64,
stride_dst: *const i64,
src: *const c_void,
idx: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_add_i64idx_f16`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_add_i64idx_f16_run`, minus workspace and stream;
/// `dst` is `*const` here. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_i64idx_f16_can_implement(
src_numel: i64,
rank: i32,
add_dim: i32,
dst_dim_size: i32,
src_shape: *const i32,
stride_src: *const i64,
stride_dst: *const i64,
src: *const c_void,
idx: *const c_void,
dst: *const c_void,
) -> i32;
/// `index_add` — bf16, i64 idx.
///
/// `src` and `idx` are read-only inputs; `dst` is accumulated into in
/// place. `workspace` may be null unless the kernel needs scratch space;
/// `stream` is a `cudaStream_t` passed as `*mut c_void`.
/// NOTE(review): shape/stride arrays presumably hold `rank` entries — confirm.
///
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_i64idx_bf16_run(
src_numel: i64,
rank: i32,
add_dim: i32,
dst_dim_size: i32,
src_shape: *const i32,
stride_src: *const i64,
stride_dst: *const i64,
src: *const c_void,
idx: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `index_add_i64idx_bf16`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_add_i64idx_bf16_run`, minus workspace and stream;
/// `dst` is `*const` here. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_i64idx_bf16_can_implement(
src_numel: i64,
rank: i32,
add_dim: i32,
dst_dim_size: i32,
src_shape: *const i32,
stride_src: *const i64,
stride_dst: *const i64,
src: *const c_void,
idx: *const c_void,
dst: *const c_void,
) -> i32;
// -- gather (u8 idx extras) --
//
// u8 idx is **not** in the Rust `IndexElement` sealed trait today;
// these symbols are FFI-only for callers like Fuel that ship their
// own index-dtype dispatch. Promotion to a full IndexElement impl
// is tracked as a follow-up.
/// `gather` FW — f32, u8 idx.
///
/// FFI-only symbol: u8 is not a Rust-side `IndexElement` today. `src` and
/// `index` are read-only inputs; `out` is the writable destination.
/// `workspace` may be null unless the kernel needs scratch space; `stream`
/// is a `cudaStream_t` passed as `*mut c_void`.
/// NOTE(review): shape/stride arrays presumably hold `rank` entries — confirm.
///
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_u8idx_f32_run(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_index: *const i64,
stride_out: *const i64,
src: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_u8idx_f32`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_gather_u8idx_f32_run`, minus workspace and stream;
/// `out` is `*const` here. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_u8idx_f32_can_implement(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_index: *const i64,
stride_out: *const i64,
src: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
/// `gather` FW — f64, u8 idx.
///
/// FFI-only symbol: u8 is not a Rust-side `IndexElement` today. `src` and
/// `index` are read-only inputs; `out` is the writable destination.
/// `workspace` may be null unless the kernel needs scratch space; `stream`
/// is a `cudaStream_t` passed as `*mut c_void`.
/// NOTE(review): shape/stride arrays presumably hold `rank` entries — confirm.
///
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_u8idx_f64_run(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_index: *const i64,
stride_out: *const i64,
src: *const c_void,
index: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_u8idx_f64`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_gather_u8idx_f64_run`, minus workspace and stream;
/// `out` is `*const` here. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_u8idx_f64_can_implement(
out_numel: i64,
rank: i32,
gather_dim: i32,
src_dim_size: i32,
out_shape: *const i32,
stride_src: *const i64,
stride_index: *const i64,
stride_out: *const i64,
src: *const c_void,
index: *const c_void,
out: *const c_void,
) -> i32;
// ---------- Phase 40 (Fuel 6c.4 Gap 6b spillover) ----------
// Integer value-dtype fanout for indexing ops. Read-only ops
// (`gather`, `index_select`) and pure-assign `scatter` cover the
// full {u8, i8, u16, i16, u32, i32, i64} matrix. `index_add` is
// gated to value dtypes with native CUDA `atomicAdd` (i32, u32,
// i64-via-ull-reinterpret).
//
// The FFI signatures match the existing fp-dtype counterparts —
// see `BARACUDA_KERNELS_*_INSTANTIATE` macros in
// `kernels/include/baracuda_indexing.cuh`. The integer specs are
// currently FFI-only (not surfaced via the Rust `IndexElement` /
// `Element` trait dispatch); promotion to plan-layer dtype
// matching is tracked as a follow-up.
// -- gather FW (integer value-dtype) --
/// `gather` FW — u8 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `src` and `index` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_u8_run(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_u8`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_gather_u8_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_u8_can_implement(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// `gather` FW — i8 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `src` and `index` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i8_run(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_i8`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_gather_i8_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i8_can_implement(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// `gather` FW — u16 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `src` and `index` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_u16_run(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_u16`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_gather_u16_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_u16_can_implement(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// `gather` FW — i16 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `src` and `index` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i16_run(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_i16`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_gather_i16_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i16_can_implement(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// `gather` FW — u32 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `src` and `index` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_u32_run(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_u32`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_gather_u32_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_u32_can_implement(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// `gather` FW — i64 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `src` and `index` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i64_run(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_i64`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_gather_i64_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i64_can_implement(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// `gather` FW — u8 values, i64 idx.
///
/// Integer value-dtype variant, currently FFI-only. `src` and `index` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i64idx_u8_run(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_i64idx_u8`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_gather_i64idx_u8_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i64idx_u8_can_implement(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// `gather` FW — i8 values, i64 idx.
///
/// Integer value-dtype variant, currently FFI-only. `src` and `index` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i64idx_i8_run(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_i64idx_i8`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_gather_i64idx_i8_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i64idx_i8_can_implement(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// `gather` FW — u16 values, i64 idx.
///
/// Integer value-dtype variant, currently FFI-only. `src` and `index` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i64idx_u16_run(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_i64idx_u16`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_gather_i64idx_u16_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i64idx_u16_can_implement(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// `gather` FW — i16 values, i64 idx.
///
/// Integer value-dtype variant, currently FFI-only. `src` and `index` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i64idx_i16_run(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_i64idx_i16`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_gather_i64idx_i16_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i64idx_i16_can_implement(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// `gather` FW — u32 values, i64 idx.
///
/// Integer value-dtype variant, currently FFI-only. `src` and `index` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i64idx_u32_run(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_i64idx_u32`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_gather_i64idx_u32_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i64idx_u32_can_implement(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// `gather` FW — i64 values, i64 idx.
///
/// Integer value-dtype variant, currently FFI-only. `src` and `index` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i64idx_i64_run(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `gather_i64idx_i64`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_gather_i64idx_i64_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_gather_i64idx_i64_can_implement(
out_numel: i64, rank: i32, gather_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64,
stride_index: *const i64, stride_out: *const i64,
src: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
// -- index_select FW (integer value-dtype) --
/// `index_select` FW — u8 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `src` and `idx` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_u8_run(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_u8`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_select_u8_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_u8_can_implement(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *const c_void,
) -> i32;
/// `index_select` FW — i8 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `src` and `idx` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i8_run(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_i8`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_select_i8_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i8_can_implement(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *const c_void,
) -> i32;
/// `index_select` FW — u16 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `src` and `idx` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_u16_run(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_u16`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_select_u16_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_u16_can_implement(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *const c_void,
) -> i32;
/// `index_select` FW — i16 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `src` and `idx` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i16_run(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_i16`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_select_i16_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i16_can_implement(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *const c_void,
) -> i32;
/// `index_select` FW — u32 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `src` and `idx` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_u32_run(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_u32`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_select_u32_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_u32_can_implement(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *const c_void,
) -> i32;
/// `index_select` FW — i64 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `src` and `idx` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i64_run(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_i64`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_select_i64_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i64_can_implement(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *const c_void,
) -> i32;
/// `index_select` FW — u8 values, i64 idx.
///
/// Integer value-dtype variant, currently FFI-only. `src` and `idx` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i64idx_u8_run(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_i64idx_u8`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_select_i64idx_u8_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i64idx_u8_can_implement(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *const c_void,
) -> i32;
/// `index_select` FW — i8 values, i64 idx.
///
/// Integer value-dtype variant, currently FFI-only. `src` and `idx` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i64idx_i8_run(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_i64idx_i8`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_select_i64idx_i8_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i64idx_i8_can_implement(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *const c_void,
) -> i32;
/// `index_select` FW — u16 values, i64 idx.
///
/// Integer value-dtype variant, currently FFI-only. `src` and `idx` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i64idx_u16_run(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_i64idx_u16`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_select_i64idx_u16_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i64idx_u16_can_implement(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *const c_void,
) -> i32;
/// `index_select` FW — i16 values, i64 idx.
///
/// Integer value-dtype variant, currently FFI-only. `src` and `idx` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i64idx_i16_run(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_i64idx_i16`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_select_i64idx_i16_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i64idx_i16_can_implement(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *const c_void,
) -> i32;
/// `index_select` FW — u32 values, i64 idx.
///
/// Integer value-dtype variant, currently FFI-only. `src` and `idx` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i64idx_u32_run(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_i64idx_u32`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_select_i64idx_u32_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i64idx_u32_can_implement(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *const c_void,
) -> i32;
/// `index_select` FW — i64 values, i64 idx.
///
/// Integer value-dtype variant, currently FFI-only. `src` and `idx` are
/// read-only inputs; `out` is the writable destination. `workspace` may be
/// null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i64idx_i64_run(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `index_select_i64idx_i64`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_index_select_i64idx_i64_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_index_select_i64idx_i64_can_implement(
out_numel: i64, rank: i32, select_dim: i32, src_dim_size: i32,
out_shape: *const i32, stride_src: *const i64, stride_out: *const i64,
src: *const c_void, idx: *const c_void, out: *const c_void,
) -> i32;
// -- scatter (pure-assign; integer value-dtype) --
/// `scatter` (pure-assign) — u8 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `updates` and `index`
/// are read-only inputs; `out` is the writable destination. `workspace` may
/// be null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_u8_run(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_u8`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_scatter_u8_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_u8_can_implement(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// `scatter` (pure-assign) — i8 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `updates` and `index`
/// are read-only inputs; `out` is the writable destination. `workspace` may
/// be null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_i8_run(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_i8`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_scatter_i8_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_i8_can_implement(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// `scatter` (pure-assign) — u16 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `updates` and `index`
/// are read-only inputs; `out` is the writable destination. `workspace` may
/// be null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_u16_run(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_u16`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_scatter_u16_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_u16_can_implement(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// `scatter` (pure-assign) — i16 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `updates` and `index`
/// are read-only inputs; `out` is the writable destination. `workspace` may
/// be null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_i16_run(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_i16`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_scatter_i16_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_i16_can_implement(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// `scatter` (pure-assign) — u32 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `updates` and `index`
/// are read-only inputs; `out` is the writable destination. `workspace` may
/// be null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_u32_run(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_u32`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_scatter_u32_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_u32_can_implement(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// `scatter` (pure-assign) — i32 values, default idx dtype (presumably i32 — confirm).
///
/// Integer value-dtype variant, currently FFI-only. `updates` and `index`
/// are read-only inputs; `out` is the writable destination. `workspace` may
/// be null unless the kernel needs scratch space; `stream` is a `cudaStream_t`
/// passed as `*mut c_void`. Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_i32_run(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `scatter_i32`.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_scatter_i32_run`, minus workspace and stream.
/// Returns a crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_i32_can_implement(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// Runs the `scatter` kernel, `i64` variant (dtype read off the symbol
/// suffix — TODO confirm against the kernel header).
///
/// Takes the problem description, the `updates` / `index` inputs and the
/// writable `out` buffer, followed by the workspace pointer, its size in
/// bytes and the CUDA stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_scatter_i64_run(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for the `scatter` kernel, `i64` variant.
///
/// Takes the same problem description and buffer pointers as
/// `baracuda_kernels_scatter_i64_run` (with `out` as `*const`) but no
/// workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_scatter_i64_can_implement(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// Runs the `scatter` kernel, `_i64idx_` `u8` variant.
///
/// NOTE(review): by analogy with the embedding `_i64idx_` symbols this
/// presumably means `u8` elements with a 64-bit `index` buffer — confirm
/// against the kernel header.
/// Same argument layout as the other `scatter` `_run` entry points;
/// returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_u8_run(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for the `scatter` kernel, `_i64idx_` `u8` variant.
///
/// Takes the same problem description and buffer pointers as
/// `baracuda_kernels_scatter_i64idx_u8_run` (with `out` as `*const`) but
/// no workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_u8_can_implement(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// Runs the `scatter` kernel, `_i64idx_` `i8` variant.
///
/// NOTE(review): by analogy with the embedding `_i64idx_` symbols this
/// presumably means `i8` elements with a 64-bit `index` buffer — confirm
/// against the kernel header.
/// Same argument layout as the other `scatter` `_run` entry points;
/// returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_i8_run(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for the `scatter` kernel, `_i64idx_` `i8` variant.
///
/// Takes the same problem description and buffer pointers as
/// `baracuda_kernels_scatter_i64idx_i8_run` (with `out` as `*const`) but
/// no workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_i8_can_implement(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// Runs the `scatter` kernel, `_i64idx_` `u16` variant.
///
/// NOTE(review): by analogy with the embedding `_i64idx_` symbols this
/// presumably means `u16` elements with a 64-bit `index` buffer — confirm
/// against the kernel header.
/// Same argument layout as the other `scatter` `_run` entry points;
/// returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_u16_run(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for the `scatter` kernel, `_i64idx_` `u16` variant.
///
/// Takes the same problem description and buffer pointers as
/// `baracuda_kernels_scatter_i64idx_u16_run` (with `out` as `*const`) but
/// no workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_u16_can_implement(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// Runs the `scatter` kernel, `_i64idx_` `i16` variant.
///
/// NOTE(review): by analogy with the embedding `_i64idx_` symbols this
/// presumably means `i16` elements with a 64-bit `index` buffer — confirm
/// against the kernel header.
/// Same argument layout as the other `scatter` `_run` entry points;
/// returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_i16_run(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for the `scatter` kernel, `_i64idx_` `i16` variant.
///
/// Takes the same problem description and buffer pointers as
/// `baracuda_kernels_scatter_i64idx_i16_run` (with `out` as `*const`) but
/// no workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_i16_can_implement(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// Runs the `scatter` kernel, `_i64idx_` `u32` variant.
///
/// NOTE(review): by analogy with the embedding `_i64idx_` symbols this
/// presumably means `u32` elements with a 64-bit `index` buffer — confirm
/// against the kernel header.
/// Same argument layout as the other `scatter` `_run` entry points;
/// returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_u32_run(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for the `scatter` kernel, `_i64idx_` `u32` variant.
///
/// Takes the same problem description and buffer pointers as
/// `baracuda_kernels_scatter_i64idx_u32_run` (with `out` as `*const`) but
/// no workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_u32_can_implement(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// Runs the `scatter` kernel, `_i64idx_` `i32` variant.
///
/// NOTE(review): by analogy with the embedding `_i64idx_` symbols this
/// presumably means `i32` elements with a 64-bit `index` buffer — confirm
/// against the kernel header.
/// Same argument layout as the other `scatter` `_run` entry points;
/// returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_i32_run(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for the `scatter` kernel, `_i64idx_` `i32` variant.
///
/// Takes the same problem description and buffer pointers as
/// `baracuda_kernels_scatter_i64idx_i32_run` (with `out` as `*const`) but
/// no workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_i32_can_implement(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
/// Runs the `scatter` kernel, `_i64idx_` `i64` variant.
///
/// NOTE(review): by analogy with the embedding `_i64idx_` symbols this
/// presumably means `i64` elements with a 64-bit `index` buffer — confirm
/// against the kernel header.
/// Same argument layout as the other `scatter` `_run` entry points;
/// returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_i64_run(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for the `scatter` kernel, `_i64idx_` `i64` variant.
///
/// Takes the same problem description and buffer pointers as
/// `baracuda_kernels_scatter_i64idx_i64_run` (with `out` as `*const`) but
/// no workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_scatter_i64idx_i64_can_implement(
upd_numel: i64, rank: i32, scatter_dim: i32, out_dim_size: i32,
upd_shape: *const i32, stride_upd: *const i64,
stride_index: *const i64, stride_out: *const i64,
updates: *const c_void, index: *const c_void, out: *const c_void,
) -> i32;
// -- index_add (sums via atomicAdd; only integer dtypes with native atomics) --
/// Runs the `index_add` kernel, `i32` variant; per the section header the
/// accumulation uses `atomicAdd`.
///
/// Takes the problem description (`src_numel`, `rank`, `add_dim`,
/// `dst_dim_size`, `src_shape`, source/destination strides), the `src` and
/// `idx` inputs and the writable `dst` buffer, followed by the workspace
/// pointer, its size in bytes and the CUDA stream. Returns the crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_index_add_i32_run(
src_numel: i64, rank: i32, add_dim: i32, dst_dim_size: i32,
src_shape: *const i32, stride_src: *const i64, stride_dst: *const i64,
src: *const c_void, idx: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for the `index_add` kernel, `i32` variant.
///
/// Takes the same problem description and buffer pointers as
/// `baracuda_kernels_index_add_i32_run` (with `dst` as `*const`) but no
/// workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_index_add_i32_can_implement(
src_numel: i64, rank: i32, add_dim: i32, dst_dim_size: i32,
src_shape: *const i32, stride_src: *const i64, stride_dst: *const i64,
src: *const c_void, idx: *const c_void, dst: *const c_void,
) -> i32;
/// Runs the `index_add` kernel, `u32` variant; per the section header the
/// accumulation uses `atomicAdd`.
///
/// Takes the problem description (`src_numel`, `rank`, `add_dim`,
/// `dst_dim_size`, `src_shape`, source/destination strides), the `src` and
/// `idx` inputs and the writable `dst` buffer, followed by the workspace
/// pointer, its size in bytes and the CUDA stream. Returns the crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_index_add_u32_run(
src_numel: i64, rank: i32, add_dim: i32, dst_dim_size: i32,
src_shape: *const i32, stride_src: *const i64, stride_dst: *const i64,
src: *const c_void, idx: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for the `index_add` kernel, `u32` variant.
///
/// Takes the same problem description and buffer pointers as
/// `baracuda_kernels_index_add_u32_run` (with `dst` as `*const`) but no
/// workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_index_add_u32_can_implement(
src_numel: i64, rank: i32, add_dim: i32, dst_dim_size: i32,
src_shape: *const i32, stride_src: *const i64, stride_dst: *const i64,
src: *const c_void, idx: *const c_void, dst: *const c_void,
) -> i32;
/// Runs the `index_add` kernel, `i64` variant; per the section header the
/// accumulation uses `atomicAdd`.
///
/// Takes the problem description (`src_numel`, `rank`, `add_dim`,
/// `dst_dim_size`, `src_shape`, source/destination strides), the `src` and
/// `idx` inputs and the writable `dst` buffer, followed by the workspace
/// pointer, its size in bytes and the CUDA stream. Returns the crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_index_add_i64_run(
src_numel: i64, rank: i32, add_dim: i32, dst_dim_size: i32,
src_shape: *const i32, stride_src: *const i64, stride_dst: *const i64,
src: *const c_void, idx: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for the `index_add` kernel, `i64` variant.
///
/// Takes the same problem description and buffer pointers as
/// `baracuda_kernels_index_add_i64_run` (with `dst` as `*const`) but no
/// workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_index_add_i64_can_implement(
src_numel: i64, rank: i32, add_dim: i32, dst_dim_size: i32,
src_shape: *const i32, stride_src: *const i64, stride_dst: *const i64,
src: *const c_void, idx: *const c_void, dst: *const c_void,
) -> i32;
/// Runs the `index_add` kernel, `_i64idx_` `i32` variant; per the section
/// header the accumulation uses `atomicAdd`.
///
/// NOTE(review): `_i64idx_` presumably means the `idx` buffer holds 64-bit
/// indices (as for the embedding `_i64idx_` symbols) — confirm against the
/// kernel header.
/// Same argument layout as the other `index_add` `_run` entry points;
/// returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_i64idx_i32_run(
src_numel: i64, rank: i32, add_dim: i32, dst_dim_size: i32,
src_shape: *const i32, stride_src: *const i64, stride_dst: *const i64,
src: *const c_void, idx: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for the `index_add` kernel, `_i64idx_` `i32`
/// variant.
///
/// Takes the same problem description and buffer pointers as
/// `baracuda_kernels_index_add_i64idx_i32_run` (with `dst` as `*const`)
/// but no workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_index_add_i64idx_i32_can_implement(
src_numel: i64, rank: i32, add_dim: i32, dst_dim_size: i32,
src_shape: *const i32, stride_src: *const i64, stride_dst: *const i64,
src: *const c_void, idx: *const c_void, dst: *const c_void,
) -> i32;
/// Runs the `index_add` kernel, `_i64idx_` `u32` variant; per the section
/// header the accumulation uses `atomicAdd`.
///
/// NOTE(review): `_i64idx_` presumably means the `idx` buffer holds 64-bit
/// indices (as for the embedding `_i64idx_` symbols) — confirm against the
/// kernel header.
/// Same argument layout as the other `index_add` `_run` entry points;
/// returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_i64idx_u32_run(
src_numel: i64, rank: i32, add_dim: i32, dst_dim_size: i32,
src_shape: *const i32, stride_src: *const i64, stride_dst: *const i64,
src: *const c_void, idx: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for the `index_add` kernel, `_i64idx_` `u32`
/// variant.
///
/// Takes the same problem description and buffer pointers as
/// `baracuda_kernels_index_add_i64idx_u32_run` (with `dst` as `*const`)
/// but no workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_index_add_i64idx_u32_can_implement(
src_numel: i64, rank: i32, add_dim: i32, dst_dim_size: i32,
src_shape: *const i32, stride_src: *const i64, stride_dst: *const i64,
src: *const c_void, idx: *const c_void, dst: *const c_void,
) -> i32;
/// Runs the `index_add` kernel, `_i64idx_` `i64` variant; per the section
/// header the accumulation uses `atomicAdd`.
///
/// NOTE(review): `_i64idx_` presumably means the `idx` buffer holds 64-bit
/// indices (as for the embedding `_i64idx_` symbols) — confirm against the
/// kernel header.
/// Same argument layout as the other `index_add` `_run` entry points;
/// returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_index_add_i64idx_i64_run(
src_numel: i64, rank: i32, add_dim: i32, dst_dim_size: i32,
src_shape: *const i32, stride_src: *const i64, stride_dst: *const i64,
src: *const c_void, idx: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for the `index_add` kernel, `_i64idx_` `i64`
/// variant.
///
/// Takes the same problem description and buffer pointers as
/// `baracuda_kernels_index_add_i64idx_i64_run` (with `dst` as `*const`)
/// but no workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_index_add_i64idx_i64_can_implement(
src_numel: i64, rank: i32, add_dim: i32, dst_dim_size: i32,
src_shape: *const i32, stride_src: *const i64, stride_dst: *const i64,
src: *const c_void, idx: *const c_void, dst: *const c_void,
) -> i32;
// ---------- embedding (Phase 7 Milestone 7.5) ----------
//
// `out[n, :] = weight[indices[n], :]`, with `padding_idx` zeroing
// rows where `indices[n] == padding_idx`. `padding_idx` is an i64
// slot; the caller passes `i32::MIN` (sign-extended) as the
// padding-disabled sentinel (`kPaddingDisabled` in the .cuh).
// `weight: [V, D]` is row-major contiguous; `indices` is a flat i32
// buffer of length `num_indices`; `out: [N, D]` is row-major
// contiguous.
/// `embedding` FW — f32 (pure copy).
///
/// Computes `out[n, :] = weight[indices[n], :]`, zeroing rows where
/// `indices[n] == padding_idx`; `indices` is a flat i32 buffer of
/// `num_indices` entries and `i32::MIN` in `padding_idx` disables padding.
/// Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_f32_run(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_f32`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_f32_run` (with `out` as `*const`), without
/// workspace or stream. Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_f32_can_implement(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
out: *const c_void,
) -> i32;
/// `embedding` FW — f64.
///
/// Computes `out[n, :] = weight[indices[n], :]`, zeroing rows where
/// `indices[n] == padding_idx`; `indices` is a flat i32 buffer of
/// `num_indices` entries and `i32::MIN` in `padding_idx` disables padding.
/// Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_f64_run(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_f64`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_f64_run` (with `out` as `*const`), without
/// workspace or stream. Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_f64_can_implement(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
out: *const c_void,
) -> i32;
/// `embedding` FW — f16.
///
/// Computes `out[n, :] = weight[indices[n], :]`, zeroing rows where
/// `indices[n] == padding_idx`; `indices` is a flat i32 buffer of
/// `num_indices` entries and `i32::MIN` in `padding_idx` disables padding.
/// Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_f16_run(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_f16`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_f16_run` (with `out` as `*const`), without
/// workspace or stream. Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_f16_can_implement(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
out: *const c_void,
) -> i32;
/// `embedding` FW — bf16.
///
/// Computes `out[n, :] = weight[indices[n], :]`, zeroing rows where
/// `indices[n] == padding_idx`; `indices` is a flat i32 buffer of
/// `num_indices` entries and `i32::MIN` in `padding_idx` disables padding.
/// Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bf16_run(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bf16`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bf16_run` (with `out` as `*const`), without
/// workspace or stream. Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bf16_can_implement(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
out: *const c_void,
) -> i32;
/// `embedding` BW — `dweight[indices[n], :] += dout[n, :]` (atomicAdd),
/// skipping rows where `indices[n] == padding_idx`. f32.
///
/// `dout` and `indices` are inputs and `dweight` is accumulated into; the
/// trailing arguments are the workspace pointer, its size in bytes and the
/// CUDA stream. Returns the crate-level status code (`0` = success).
/// NOTE(review): because the kernel accumulates, `dweight` presumably has
/// to be initialised by the caller — confirm against the kernel header.
pub fn baracuda_kernels_embedding_backward_f32_run(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
dout: *const c_void,
indices: *const c_void,
dweight: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_backward_f32`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_backward_f32_run` (with `dweight` as
/// `*const`), without workspace or stream. Returns the crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_embedding_backward_f32_can_implement(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
dout: *const c_void,
indices: *const c_void,
dweight: *const c_void,
) -> i32;
/// `embedding` BW — f64.
///
/// f64 counterpart of `baracuda_kernels_embedding_backward_f32_run`, with
/// an identical argument list: `dout` and `indices` in, `dweight` out,
/// then workspace pointer, workspace size and CUDA stream. Returns the
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_backward_f64_run(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
dout: *const c_void,
indices: *const c_void,
dweight: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_backward_f64`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_backward_f64_run` (with `dweight` as
/// `*const`), without workspace or stream. Returns the crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_embedding_backward_f64_can_implement(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
dout: *const c_void,
indices: *const c_void,
dweight: *const c_void,
) -> i32;
// ---------- embedding_bag (Phase 7 Milestone 7.5) ----------
//
// `out[b, :] = reduce(weight[indices[k], :] for k in
// offsets[b]..offsets[b+1])`. `mode` is 0 (Sum) or 1 (Mean). The
// last bag's end is `total_indices`. `padding_idx` skips rows;
// for Mean the divisor uses the post-skip count.
/// `embedding_bag` FW — f32.
///
/// Computes `out[b, :]` as the reduction of `weight[indices[k], :]` over
/// `k` in `offsets[b]..offsets[b+1]` (the last bag ends at
/// `total_indices`). `mode` is 0 (Sum) or 1 (Mean); rows matching
/// `padding_idx` are skipped and Mean divides by the post-skip count.
/// Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bag_f32_run(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_f32`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bag_f32_run` (with `out` as `*const`),
/// without workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_embedding_bag_f32_can_implement(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *const c_void,
) -> i32;
/// `embedding_bag` FW — f64.
///
/// Computes `out[b, :]` as the reduction of `weight[indices[k], :]` over
/// `k` in `offsets[b]..offsets[b+1]` (the last bag ends at
/// `total_indices`). `mode` is 0 (Sum) or 1 (Mean); rows matching
/// `padding_idx` are skipped and Mean divides by the post-skip count.
/// Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bag_f64_run(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_f64`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bag_f64_run` (with `out` as `*const`),
/// without workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_embedding_bag_f64_can_implement(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *const c_void,
) -> i32;
/// `embedding_bag` FW — f16.
///
/// Computes `out[b, :]` as the reduction of `weight[indices[k], :]` over
/// `k` in `offsets[b]..offsets[b+1]` (the last bag ends at
/// `total_indices`). `mode` is 0 (Sum) or 1 (Mean); rows matching
/// `padding_idx` are skipped and Mean divides by the post-skip count.
/// Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bag_f16_run(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_f16`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bag_f16_run` (with `out` as `*const`),
/// without workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_embedding_bag_f16_can_implement(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *const c_void,
) -> i32;
/// `embedding_bag` FW — bf16.
///
/// Computes `out[b, :]` as the reduction of `weight[indices[k], :]` over
/// `k` in `offsets[b]..offsets[b+1]` (the last bag ends at
/// `total_indices`). `mode` is 0 (Sum) or 1 (Mean); rows matching
/// `padding_idx` are skipped and Mean divides by the post-skip count.
/// Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bag_bf16_run(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_bf16`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bag_bf16_run` (with `out` as `*const`),
/// without workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_embedding_bag_bf16_can_implement(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *const c_void,
) -> i32;
/// `embedding_bag` BW — atomicAdd into `dweight`. f32.
///
/// Takes the same bag description as the forward entry points (`mode` is
/// 0 for Sum or 1 for Mean), with `dout` in place of `weight` and the
/// writable `dweight` in place of `out`. Returns the crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_embedding_bag_backward_f32_run(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
dout: *const c_void,
indices: *const c_void,
offsets: *const c_void,
dweight: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_backward_f32`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bag_backward_f32_run` (with `dweight` as
/// `*const`), without workspace or stream. Returns the crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_embedding_bag_backward_f32_can_implement(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
dout: *const c_void,
indices: *const c_void,
offsets: *const c_void,
dweight: *const c_void,
) -> i32;
/// `embedding_bag` BW — f64.
///
/// f64 counterpart of `baracuda_kernels_embedding_bag_backward_f32_run`,
/// with an identical argument list: `dout`, `indices` and `offsets` in,
/// `dweight` out, then workspace pointer, workspace size and CUDA stream.
/// Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bag_backward_f64_run(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
dout: *const c_void,
indices: *const c_void,
offsets: *const c_void,
dweight: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_backward_f64`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bag_backward_f64_run` (with `dweight` as
/// `*const`), without workspace or stream. Returns the crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_embedding_bag_backward_f64_can_implement(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
dout: *const c_void,
indices: *const c_void,
offsets: *const c_void,
dweight: *const c_void,
) -> i32;
// ---------- i64-index variants (Phase 11.5 / Fuel team feedback #7) ----------
//
// PyTorch defaults `embedding` / `embedding_bag` indices to int64.
// The legacy entry points above keep their i32 ABI; these new
// `_i64idx_` symbols accept int64 index buffers directly.
//
// `padding_idx` carries int64 in both surfaces — i32 callers
// sign-extend their value (or `i32::MIN` sentinel) into the
// 64-bit slot on the way in.
/// `embedding` FW — f32, i64 indices.
///
/// Same argument list as `baracuda_kernels_embedding_f32_run`, except that
/// `indices` is an int64 buffer. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_embedding_i64idx_f32_run(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_i64idx_f32`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_i64idx_f32_run` (with `out` as `*const`),
/// without workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_embedding_i64idx_f32_can_implement(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
out: *const c_void,
) -> i32;
/// `embedding` FW — f64, i64 indices.
///
/// Same argument list as `baracuda_kernels_embedding_f64_run`, except that
/// `indices` is an int64 buffer. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_embedding_i64idx_f64_run(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_i64idx_f64`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_i64idx_f64_run` (with `out` as `*const`),
/// without workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_embedding_i64idx_f64_can_implement(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
out: *const c_void,
) -> i32;
/// `embedding` FW — f16, i64 indices.
///
/// Same argument list as `baracuda_kernels_embedding_f16_run`, except that
/// `indices` is an int64 buffer. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_embedding_i64idx_f16_run(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_i64idx_f16`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_i64idx_f16_run` (with `out` as `*const`),
/// without workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_embedding_i64idx_f16_can_implement(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
out: *const c_void,
) -> i32;
/// `embedding` FW — bf16, i64 indices.
///
/// Same argument list as `baracuda_kernels_embedding_bf16_run`, except
/// that `indices` is an int64 buffer. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_embedding_i64idx_bf16_run(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_i64idx_bf16`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_i64idx_bf16_run` (with `out` as `*const`),
/// without workspace or stream. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_embedding_i64idx_bf16_can_implement(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
out: *const c_void,
) -> i32;
/// `embedding` BW — f32, i64 indices.
///
/// Same argument list as `baracuda_kernels_embedding_backward_f32_run`,
/// except that `indices` is an int64 buffer; `dweight` is the writable
/// output. Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_backward_i64idx_f32_run(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
dout: *const c_void,
indices: *const c_void,
dweight: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_backward_i64idx_f32`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_backward_i64idx_f32_run` (with `dweight` as
/// `*const`), without workspace or stream. Returns the crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_embedding_backward_i64idx_f32_can_implement(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
dout: *const c_void,
indices: *const c_void,
dweight: *const c_void,
) -> i32;
/// `embedding` BW — f64, i64 indices.
///
/// Same argument list as `baracuda_kernels_embedding_backward_f64_run`,
/// except that `indices` is an int64 buffer; `dweight` is the writable
/// output. Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_backward_i64idx_f64_run(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
dout: *const c_void,
indices: *const c_void,
dweight: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_backward_i64idx_f64`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_backward_i64idx_f64_run` (with `dweight` as
/// `*const`), without workspace or stream. Returns the crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_embedding_backward_i64idx_f64_can_implement(
num_indices: i64,
num_embeddings: i32,
embedding_dim: i32,
padding_idx: i64,
dout: *const c_void,
indices: *const c_void,
dweight: *const c_void,
) -> i32;
/// `embedding_bag` FW — f32, i64 indices.
///
/// Same argument list as `baracuda_kernels_embedding_bag_f32_run`, except
/// that `indices` is an int64 buffer. `mode` is 0 (Sum) or 1 (Mean).
/// Returns the crate-level status code (`0` = success).
/// NOTE(review): the dtype of `offsets` for this variant is not visible
/// here — confirm against the kernel header.
pub fn baracuda_kernels_embedding_bag_i64idx_f32_run(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_i64idx_f32`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bag_i64idx_f32_run` (with `out` as
/// `*const`), without workspace or stream. Returns the crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_embedding_bag_i64idx_f32_can_implement(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *const c_void,
) -> i32;
/// `embedding_bag` FW — f64, i64 indices.
///
/// Same argument list as `baracuda_kernels_embedding_bag_f64_run`, except
/// that `indices` is an int64 buffer. `mode` is 0 (Sum) or 1 (Mean).
/// Returns the crate-level status code (`0` = success).
/// NOTE(review): the dtype of `offsets` for this variant is not visible
/// here — confirm against the kernel header.
pub fn baracuda_kernels_embedding_bag_i64idx_f64_run(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_i64idx_f64`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bag_i64idx_f64_run` (with `out` as
/// `*const`), without workspace or stream. Returns the crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_embedding_bag_i64idx_f64_can_implement(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *const c_void,
) -> i32;
/// `embedding_bag` FW — f16, i64 indices.
///
/// Same argument list as `baracuda_kernels_embedding_bag_f16_run`, except
/// that `indices` is an int64 buffer. `mode` is 0 (Sum) or 1 (Mean).
/// Returns the crate-level status code (`0` = success).
/// NOTE(review): the dtype of `offsets` for this variant is not visible
/// here — confirm against the kernel header.
pub fn baracuda_kernels_embedding_bag_i64idx_f16_run(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_i64idx_f16`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bag_i64idx_f16_run` (with `out` as
/// `*const`), without workspace or stream. Returns the crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_embedding_bag_i64idx_f16_can_implement(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *const c_void,
) -> i32;
/// `embedding_bag` FW — bf16, i64 indices.
///
/// Same argument list as `baracuda_kernels_embedding_bag_bf16_run`, except
/// that `indices` is an int64 buffer. `mode` is 0 (Sum) or 1 (Mean).
/// Returns the crate-level status code (`0` = success).
/// NOTE(review): the dtype of `offsets` for this variant is not visible
/// here — confirm against the kernel header.
pub fn baracuda_kernels_embedding_bag_i64idx_bf16_run(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_i64idx_bf16`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bag_i64idx_bf16_run` (with `out` as
/// `*const`), without workspace or stream. Returns the crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_embedding_bag_i64idx_bf16_can_implement(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *const c_void,
) -> i32;
/// `embedding_bag` BW — f32, i64 indices.
///
/// Same argument list as
/// `baracuda_kernels_embedding_bag_backward_f32_run`, except that
/// `indices` is an int64 buffer; `dweight` is the writable output.
/// Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bag_backward_i64idx_f32_run(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
dout: *const c_void,
indices: *const c_void,
offsets: *const c_void,
dweight: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_backward_i64idx_f32`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bag_backward_i64idx_f32_run` (with
/// `dweight` as `*const`), without workspace or stream. Returns the
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bag_backward_i64idx_f32_can_implement(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
dout: *const c_void,
indices: *const c_void,
offsets: *const c_void,
dweight: *const c_void,
) -> i32;
/// `embedding_bag` BW — f64, i64 indices.
///
/// Same argument list as
/// `baracuda_kernels_embedding_bag_backward_f64_run`, except that
/// `indices` is an int64 buffer; `dweight` is the writable output.
/// Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bag_backward_i64idx_f64_run(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
dout: *const c_void,
indices: *const c_void,
offsets: *const c_void,
dweight: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_backward_i64idx_f64`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bag_backward_i64idx_f64_run` (with
/// `dweight` as `*const`), without workspace or stream. Returns the
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bag_backward_i64idx_f64_can_implement(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
mode: i32,
padding_idx: i64,
dout: *const c_void,
indices: *const c_void,
offsets: *const c_void,
dweight: *const c_void,
) -> i32;
// ---------- Phase 25: embedding_bag Max mode (FW + BW) ----------
//
// FW writes value + per-(b, d) contributing row index `out_index`
// (i32). BW scatters dout into dweight at those rows via atomicAdd
// — value dtype is generic (T), index dtype is fixed at i32.
//
// Tie-break = first occurrence (lowest k in the bag). PyTorch
// chooses last; documented divergence.
/// `embedding_bag` Max-mode FW — f32 (i32 indices).
///
/// Writes the Max-mode value to `out` and, for each `(b, d)`, the
/// contributing row index to `out_index` (i32). Ties resolve to the first
/// occurrence (lowest `k` in the bag), which differs from PyTorch.
/// There is no `mode` argument. Returns the crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_embedding_bag_max_f32_run(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *mut c_void,
out_index: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_max_f32`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bag_max_f32_run` (with `out` and
/// `out_index` as `*const`), without workspace or stream. Returns the
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bag_max_f32_can_implement(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *const c_void,
out_index: *const c_void,
) -> i32;
/// `embedding_bag_max` FW — f64.
///
/// Writes the Max-mode value to `out` and, for each `(b, d)`, the
/// contributing row index to `out_index` (i32). Ties resolve to the first
/// occurrence (lowest `k` in the bag), which differs from PyTorch.
/// Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bag_max_f64_run(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *mut c_void,
out_index: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_max_f64`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bag_max_f64_run` (with `out` and
/// `out_index` as `*const`), without workspace or stream. Returns the
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bag_max_f64_can_implement(
total_indices: i32,
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
padding_idx: i64,
weight: *const c_void,
indices: *const c_void,
offsets: *const c_void,
out: *const c_void,
out_index: *const c_void,
) -> i32;
/// `embedding_bag_max` FW — f16.
///
/// Writes the Max-mode value to `out` and, for each `(b, d)`, the
/// contributing row index to `out_index` (i32). Ties resolve to the first
/// occurrence (lowest `k` in the bag), which differs from PyTorch.
/// Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bag_max_f16_run(
total_indices: i32, num_embeddings: i32, embedding_dim: i32,
num_bags: i32, padding_idx: i64,
weight: *const c_void, indices: *const c_void, offsets: *const c_void,
out: *mut c_void, out_index: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_max_f16`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bag_max_f16_run` (with `out` and
/// `out_index` as `*const`), without workspace or stream. Returns the
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bag_max_f16_can_implement(
total_indices: i32, num_embeddings: i32, embedding_dim: i32,
num_bags: i32, padding_idx: i64,
weight: *const c_void, indices: *const c_void, offsets: *const c_void,
out: *const c_void, out_index: *const c_void,
) -> i32;
/// `embedding_bag_max` FW — bf16.
///
/// Writes the Max-mode value to `out` and, for each `(b, d)`, the
/// contributing row index to `out_index` (i32). Ties resolve to the first
/// occurrence (lowest `k` in the bag), which differs from PyTorch.
/// Returns the crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bag_max_bf16_run(
total_indices: i32, num_embeddings: i32, embedding_dim: i32,
num_bags: i32, padding_idx: i64,
weight: *const c_void, indices: *const c_void, offsets: *const c_void,
out: *mut c_void, out_index: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_max_bf16`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bag_max_bf16_run` (with `out` and
/// `out_index` as `*const`), without workspace or stream. Returns the
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bag_max_bf16_can_implement(
total_indices: i32, num_embeddings: i32, embedding_dim: i32,
num_bags: i32, padding_idx: i64,
weight: *const c_void, indices: *const c_void, offsets: *const c_void,
out: *const c_void, out_index: *const c_void,
) -> i32;
/// `embedding_bag_max` FW — f32, i64 indices.
///
/// Same argument list as `baracuda_kernels_embedding_bag_max_f32_run`,
/// with `indices` as an int64 buffer. Returns the crate-level status code
/// (`0` = success).
/// NOTE(review): `out_index` is presumably still i32, as the section
/// header states for Max mode — confirm against the kernel header.
pub fn baracuda_kernels_embedding_bag_max_i64idx_f32_run(
total_indices: i32, num_embeddings: i32, embedding_dim: i32,
num_bags: i32, padding_idx: i64,
weight: *const c_void, indices: *const c_void, offsets: *const c_void,
out: *mut c_void, out_index: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_max_i64idx_f32`.
///
/// Same problem description and buffer pointers as
/// `baracuda_kernels_embedding_bag_max_i64idx_f32_run` (with `out` and
/// `out_index` as `*const`), without workspace or stream. Returns the
/// crate-level status code (`0` = success).
pub fn baracuda_kernels_embedding_bag_max_i64idx_f32_can_implement(
total_indices: i32, num_embeddings: i32, embedding_dim: i32,
num_bags: i32, padding_idx: i64,
weight: *const c_void, indices: *const c_void, offsets: *const c_void,
out: *const c_void, out_index: *const c_void,
) -> i32;
/// `embedding_bag_max` FW — f64, i64 indices.
///
/// Same argument list as `baracuda_kernels_embedding_bag_max_f64_run`,
/// with `indices` as an int64 buffer. Returns the crate-level status code
/// (`0` = success).
/// NOTE(review): `out_index` is presumably still i32, as the section
/// header states for Max mode — confirm against the kernel header.
pub fn baracuda_kernels_embedding_bag_max_i64idx_f64_run(
total_indices: i32, num_embeddings: i32, embedding_dim: i32,
num_bags: i32, padding_idx: i64,
weight: *const c_void, indices: *const c_void, offsets: *const c_void,
out: *mut c_void, out_index: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_max_i64idx_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_embedding_bag_max_i64idx_f64_can_implement(
total_indices: i32, num_embeddings: i32, embedding_dim: i32,
num_bags: i32, padding_idx: i64,
weight: *const c_void, indices: *const c_void, offsets: *const c_void,
out: *const c_void, out_index: *const c_void,
) -> i32;
/// `embedding_bag_max` FW — f16, i64 indices.
///
/// `weight`, `indices` and `offsets` are `*const` inputs; `out` and
/// `out_index` are the `*mut` outputs. `out_index` feeds the
/// `embedding_bag_max_backward_*` entry points, whose docs fix its
/// dtype at i32.
/// NOTE(review): buffer shapes and the dtype of `offsets` are not
/// stated here — confirm against the kernel source.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_embedding_bag_max_i64idx_f16_run(
total_indices: i32, num_embeddings: i32, embedding_dim: i32,
num_bags: i32, padding_idx: i64,
weight: *const c_void, indices: *const c_void, offsets: *const c_void,
out: *mut c_void, out_index: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_max_i64idx_f16`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_embedding_bag_max_i64idx_f16_can_implement(
total_indices: i32, num_embeddings: i32, embedding_dim: i32,
num_bags: i32, padding_idx: i64,
weight: *const c_void, indices: *const c_void, offsets: *const c_void,
out: *const c_void, out_index: *const c_void,
) -> i32;
/// `embedding_bag_max` FW — bf16, i64 indices.
///
/// `weight`, `indices` and `offsets` are `*const` inputs; `out` and
/// `out_index` are the `*mut` outputs. `out_index` feeds the
/// `embedding_bag_max_backward_*` entry points, whose docs fix its
/// dtype at i32.
/// NOTE(review): buffer shapes and the dtype of `offsets` are not
/// stated here — confirm against the kernel source.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_embedding_bag_max_i64idx_bf16_run(
total_indices: i32, num_embeddings: i32, embedding_dim: i32,
num_bags: i32, padding_idx: i64,
weight: *const c_void, indices: *const c_void, offsets: *const c_void,
out: *mut c_void, out_index: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_max_i64idx_bf16`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_embedding_bag_max_i64idx_bf16_can_implement(
total_indices: i32, num_embeddings: i32, embedding_dim: i32,
num_bags: i32, padding_idx: i64,
weight: *const c_void, indices: *const c_void, offsets: *const c_void,
out: *const c_void, out_index: *const c_void,
) -> i32;
/// `embedding_bag_max` BW — f32. Index dtype is fixed at i32
/// (set by the FW's `out_index` output).
///
/// `dout` and `out_index` are `*const` inputs; `dweight` is the `*mut`
/// output.
/// NOTE(review): buffer shapes and workspace requirements are not
/// stated here — confirm against the kernel source.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_embedding_bag_max_backward_f32_run(
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
dout: *const c_void,
out_index: *const c_void,
dweight: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_max_backward_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_embedding_bag_max_backward_f32_can_implement(
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
dout: *const c_void,
out_index: *const c_void,
dweight: *const c_void,
) -> i32;
/// `embedding_bag_max` BW — f64.
///
/// `dout` and `out_index` are `*const` inputs; `dweight` is the `*mut`
/// output.
/// NOTE(review): `out_index` is presumably i32, as documented on the
/// f32 BW entry point; buffer shapes and workspace requirements are
/// not stated here — confirm against the kernel source.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_embedding_bag_max_backward_f64_run(
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
dout: *const c_void,
out_index: *const c_void,
dweight: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `embedding_bag_max_backward_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_embedding_bag_max_backward_f64_can_implement(
num_embeddings: i32,
embedding_dim: i32,
num_bags: i32,
dout: *const c_void,
out_index: *const c_void,
dweight: *const c_void,
) -> i32;
}
// ============================================================================
// Phase 7 Milestone 7.6 — Segment / scatter-reduce (Category S)
// ============================================================================
//
// Nine FW ops + four BW ops:
// Sorted FW: segment_sum, segment_mean, segment_max, segment_min,
// segment_prod (segment_ids monotonically non-decreasing)
// Unsorted FW: unsorted_segment_sum, unsorted_segment_mean,
// unsorted_segment_max, unsorted_segment_min
// (atomic scatter; unsorted_prod deferred — no native FP
// atomicMul)
// BW: segment_sum_backward, segment_mean_backward,
// unsorted_segment_sum_backward, unsorted_segment_mean_backward
// (sorted + unsorted share the BW launcher — the gather
// access pattern is identical)
//
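// All `*_run` FFI signatures share the shape:
// (n, d, num_segments, input/d_output, segment_ids, output/d_input,
// workspace, workspace_bytes, stream)
// The matching `*_can_implement` checks take only the first six
// arguments (no workspace or stream).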
//
// Workspace usage:
// - Sorted FW (sum/mean/max/min/prod): none (workspace_bytes ignored).
// - Unsorted FW sum/max/min: none.
// - Unsorted FW mean: requires `num_segments * sizeof(i32)` workspace
// for per-segment counts.
// - Sum BW: none.
// - Mean BW (sorted or unsorted): requires `num_segments * sizeof(i32)`
// workspace for per-segment counts.
//
// Dtype coverage: f32, f64 (atomic-FP-restricted). f16 / bf16 deferred.
// Out-of-range segment IDs are silently dropped.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---------- Sorted FW ----------
/// `out[s, d] = Σ_{n : seg[n] == s} input[n, d]` — sorted seg ids
/// (monotonically non-decreasing). f32.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// No workspace is needed (`workspace_bytes` is ignored). Out-of-range
/// segment ids are silently dropped. Returns a crate-level status
/// code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_sum_f32_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_sum_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_segment_sum_f32_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
/// `segment_sum` — f64; sorted (monotonically non-decreasing) seg ids.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// No workspace is needed (`workspace_bytes` is ignored). Out-of-range
/// segment ids are silently dropped. Returns a crate-level status
/// code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_sum_f64_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_sum_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_segment_sum_f64_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
/// `out[s, d] = mean_{n : seg[n] == s} input[n, d]` — sorted. f32.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// No workspace is needed (`workspace_bytes` is ignored). Out-of-range
/// segment ids are silently dropped. Returns a crate-level status
/// code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_mean_f32_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_mean_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_segment_mean_f32_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
/// `segment_mean` — f64; sorted (monotonically non-decreasing) seg ids.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// No workspace is needed (`workspace_bytes` is ignored). Out-of-range
/// segment ids are silently dropped. Returns a crate-level status
/// code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_mean_f64_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_mean_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_segment_mean_f64_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
/// `out[s, d] = max_{n : seg[n] == s} input[n, d]` — sorted. f32.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// No workspace is needed (`workspace_bytes` is ignored). Out-of-range
/// segment ids are silently dropped. Returns a crate-level status
/// code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_max_f32_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_max_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_segment_max_f32_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
/// `segment_max` — f64; sorted (monotonically non-decreasing) seg ids.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// No workspace is needed (`workspace_bytes` is ignored). Out-of-range
/// segment ids are silently dropped. Returns a crate-level status
/// code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_max_f64_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_max_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_segment_max_f64_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
/// `out[s, d] = min_{n : seg[n] == s} input[n, d]` — sorted. f32.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// No workspace is needed (`workspace_bytes` is ignored). Out-of-range
/// segment ids are silently dropped. Returns a crate-level status
/// code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_min_f32_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_min_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_segment_min_f32_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
/// `segment_min` — f64; sorted (monotonically non-decreasing) seg ids.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// No workspace is needed (`workspace_bytes` is ignored). Out-of-range
/// segment ids are silently dropped. Returns a crate-level status
/// code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_min_f64_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_min_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_segment_min_f64_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
/// `out[s, d] = prod_{n : seg[n] == s} input[n, d]` — sorted. f32.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// No workspace is needed (`workspace_bytes` is ignored). Out-of-range
/// segment ids are silently dropped. Returns a crate-level status
/// code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_prod_f32_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_prod_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_segment_prod_f32_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
/// `segment_prod` — f64; sorted (monotonically non-decreasing) seg ids.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// No workspace is needed (`workspace_bytes` is ignored). Out-of-range
/// segment ids are silently dropped. Returns a crate-level status
/// code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_prod_f64_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_prod_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_segment_prod_f64_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
// ---------- Unsorted FW ----------
/// `out[s, d] = Σ_{n : seg[n] == s} input[n, d]` — unsorted seg
/// ids; atomicAdd into output. Output pre-zeroed by the launcher.
/// f32.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// No workspace is needed. Out-of-range segment ids are silently
/// dropped. Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_sum_f32_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_sum_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_sum_f32_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
/// `unsorted_segment_sum` — f64; seg ids may be in any order.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// No workspace is needed. Out-of-range segment ids are silently
/// dropped. Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_sum_f64_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_sum_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_sum_f64_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
/// `out[s, d] = mean_{n : seg[n] == s} input[n, d]` — unsorted.
/// Workspace: `num_segments * sizeof(i32)` for per-segment counts.
/// f32.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// Status `4` is the crate-level code for a workspace that is null or
/// too small. Out-of-range segment ids are silently dropped. Returns
/// a crate-level status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses, `workspace` must cover
/// `workspace_bytes` of writable device memory, and `stream` must be
/// a valid CUDA stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_mean_f32_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_mean_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_mean_f32_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
/// `unsorted_segment_mean` — f64; seg ids may be in any order.
/// Workspace: `num_segments * sizeof(i32)` for per-segment counts.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// Status `4` is the crate-level code for a workspace that is null or
/// too small. Out-of-range segment ids are silently dropped. Returns
/// a crate-level status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses, `workspace` must cover
/// `workspace_bytes` of writable device memory, and `stream` must be
/// a valid CUDA stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_mean_f64_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_mean_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_mean_f64_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
/// `out[s, d] = max_{n : seg[n] == s} input[n, d]` — unsorted;
/// atomicMax-via-CAS. Output pre-initialized to `-inf` by the
/// launcher. f32.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// No workspace is needed. Out-of-range segment ids are silently
/// dropped. Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_max_f32_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_max_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_max_f32_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
/// `unsorted_segment_max` — f64; seg ids may be in any order.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// No workspace is needed. Out-of-range segment ids are silently
/// dropped. Returns a crate-level status code (`0` = success).
/// NOTE(review): presumably shares the f32 entry point's CAS-based
/// atomic max and `-inf` pre-initialization — confirm.
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_max_f64_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_max_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_max_f64_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
/// `out[s, d] = min_{n : seg[n] == s} input[n, d]` — unsorted. f32.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// No workspace is needed. Out-of-range segment ids are silently
/// dropped. Returns a crate-level status code (`0` = success).
/// NOTE(review): the launcher presumably pre-initializes the output
/// (as documented for `unsorted_segment_max`) — confirm the fill value.
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_min_f32_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_min_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_min_f32_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
/// `unsorted_segment_min` — f64; seg ids may be in any order.
///
/// `segment_ids` holds i32 ids (i64 ids use the `_i64idx_` symbols).
/// No workspace is needed. Out-of-range segment ids are silently
/// dropped. Returns a crate-level status code (`0` = success).
/// NOTE(review): the launcher presumably pre-initializes the output
/// (as documented for `unsorted_segment_max`) — confirm the fill value.
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_min_f64_run(
n: i32,
d: i32,
num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_min_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_min_f64_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void, segment_ids: *const c_void, output: *const c_void,
) -> i32;
// ---------- BW (sum / mean) ----------
//
// Sorted and unsorted variants share the same BW kernel — the
// access pattern (`d_input[n, d] = d_output[seg[n], d]` for sum;
// same divided by count for mean) is identical regardless of seg-
// ids ordering.
/// `d_input[n, d] = d_output[seg[n], d]`. f32.
///
/// `d_output` and `segment_ids` (i32 ids) are inputs; `d_input` is the
/// output. No workspace is needed. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_sum_backward_f32_run(
n: i32,
d: i32,
num_segments: i32,
d_output: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_sum_backward_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_segment_sum_backward_f32_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void, segment_ids: *const c_void, d_input: *const c_void,
) -> i32;
/// `segment_sum_backward` — f64.
///
/// `d_output` and `segment_ids` (i32 ids) are inputs; `d_input` is the
/// output. No workspace is needed. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_sum_backward_f64_run(
n: i32,
d: i32,
num_segments: i32,
d_output: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_sum_backward_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_segment_sum_backward_f64_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void, segment_ids: *const c_void, d_input: *const c_void,
) -> i32;
/// Same kernel as `segment_sum_backward_f32`; distinct symbol for
/// SKU-tagging differentiation.
///
/// `d_output` and `segment_ids` (i32 ids) are inputs; `d_input` is the
/// output. No workspace is needed. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_sum_backward_f32_run(
n: i32,
d: i32,
num_segments: i32,
d_output: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_sum_backward_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_sum_backward_f32_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void, segment_ids: *const c_void, d_input: *const c_void,
) -> i32;
/// `unsorted_segment_sum_backward` — f64.
///
/// `d_output` and `segment_ids` (i32 ids) are inputs; `d_input` is the
/// output. No workspace is needed. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_sum_backward_f64_run(
n: i32,
d: i32,
num_segments: i32,
d_output: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_sum_backward_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_sum_backward_f64_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void, segment_ids: *const c_void, d_input: *const c_void,
) -> i32;
/// `d_input[n, d] = d_output[seg[n], d] / count[seg[n]]`.
/// Workspace: `num_segments * sizeof(i32)`. f32.
///
/// `d_output` and `segment_ids` (i32 ids) are inputs; `d_input` is the
/// output. The workspace holds the per-segment counts; status `4` is
/// the crate-level code for a workspace that is null or too small.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses, `workspace` must cover
/// `workspace_bytes` of writable device memory, and `stream` must be
/// a valid CUDA stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_mean_backward_f32_run(
n: i32,
d: i32,
num_segments: i32,
d_output: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_mean_backward_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_segment_mean_backward_f32_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void, segment_ids: *const c_void, d_input: *const c_void,
) -> i32;
/// `segment_mean_backward` — f64.
/// Workspace: `num_segments * sizeof(i32)` for per-segment counts.
///
/// `d_output` and `segment_ids` (i32 ids) are inputs; `d_input` is the
/// output. Status `4` is the crate-level code for a workspace that is
/// null or too small. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses, `workspace` must cover
/// `workspace_bytes` of writable device memory, and `stream` must be
/// a valid CUDA stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_mean_backward_f64_run(
n: i32,
d: i32,
num_segments: i32,
d_output: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_mean_backward_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_segment_mean_backward_f64_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void, segment_ids: *const c_void, d_input: *const c_void,
) -> i32;
/// `unsorted_segment_mean_backward` — f32.
/// Workspace: `num_segments * sizeof(i32)` for per-segment counts.
///
/// `d_output` and `segment_ids` (i32 ids) are inputs; `d_input` is the
/// output. Status `4` is the crate-level code for a workspace that is
/// null or too small. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses, `workspace` must cover
/// `workspace_bytes` of writable device memory, and `stream` must be
/// a valid CUDA stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_mean_backward_f32_run(
n: i32,
d: i32,
num_segments: i32,
d_output: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_mean_backward_f32`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_mean_backward_f32_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void, segment_ids: *const c_void, d_input: *const c_void,
) -> i32;
/// `unsorted_segment_mean_backward` — f64.
/// Workspace: `num_segments * sizeof(i32)` for per-segment counts.
///
/// `d_output` and `segment_ids` (i32 ids) are inputs; `d_input` is the
/// output. Status `4` is the crate-level code for a workspace that is
/// null or too small. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses, `workspace` must cover
/// `workspace_bytes` of writable device memory, and `stream` must be
/// a valid CUDA stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_mean_backward_f64_run(
n: i32,
d: i32,
num_segments: i32,
d_output: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_mean_backward_f64`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_mean_backward_f64_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void, segment_ids: *const c_void, d_input: *const c_void,
) -> i32;
// ---------- segment / scatter-reduce i64 variants (Phase 11.5) ----------
//
// PyTorch / TensorFlow / JAX scatter-reduce segment ids default to
// int64. The `_i64idx_` symbols below mirror the i32 surface,
// differing only in the dereferenced type of the `segment_ids`
// buffer.
/// `segment_sum` — f32, i64 segment ids. Mirrors
/// `baracuda_kernels_segment_sum_f32_run`; only the dereferenced type
/// of `segment_ids` differs. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_sum_i64idx_f32_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `segment_sum_i64idx_f32`. Takes the
/// `_run` arguments minus `workspace`, `workspace_bytes` and `stream`,
/// with every pointer `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_segment_sum_i64idx_f32_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `segment_sum` — f64, i64 segment ids. Mirrors
/// `baracuda_kernels_segment_sum_f64_run`; only the dereferenced type
/// of `segment_ids` differs. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_sum_i64idx_f64_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `segment_sum_i64idx_f64`. Takes the
/// `_run` arguments minus `workspace`, `workspace_bytes` and `stream`,
/// with every pointer `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_segment_sum_i64idx_f64_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `segment_mean` — f32, i64 segment ids. Mirrors
/// `baracuda_kernels_segment_mean_f32_run`; only the dereferenced type
/// of `segment_ids` differs. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_mean_i64idx_f32_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `segment_mean_i64idx_f32`. Takes the
/// `_run` arguments minus `workspace`, `workspace_bytes` and `stream`,
/// with every pointer `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_segment_mean_i64idx_f32_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `segment_mean` — f64, i64 segment ids. Mirrors
/// `baracuda_kernels_segment_mean_f64_run`; only the dereferenced type
/// of `segment_ids` differs. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_mean_i64idx_f64_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `segment_mean_i64idx_f64`. Takes the
/// `_run` arguments minus `workspace`, `workspace_bytes` and `stream`,
/// with every pointer `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_segment_mean_i64idx_f64_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `segment_max` — f32, i64 segment ids. Mirrors
/// `baracuda_kernels_segment_max_f32_run`; only the dereferenced type
/// of `segment_ids` differs. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_max_i64idx_f32_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `segment_max_i64idx_f32`. Takes the
/// `_run` arguments minus `workspace`, `workspace_bytes` and `stream`,
/// with every pointer `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_segment_max_i64idx_f32_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `segment_max` — f64, i64 segment ids. Mirrors
/// `baracuda_kernels_segment_max_f64_run`; only the dereferenced type
/// of `segment_ids` differs. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_max_i64idx_f64_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `segment_max_i64idx_f64`. Takes the
/// `_run` arguments minus `workspace`, `workspace_bytes` and `stream`,
/// with every pointer `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_segment_max_i64idx_f64_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `segment_min` — f32, i64 segment ids. Mirrors
/// `baracuda_kernels_segment_min_f32_run`; only the dereferenced type
/// of `segment_ids` differs. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_min_i64idx_f32_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `segment_min_i64idx_f32`. Takes the
/// `_run` arguments minus `workspace`, `workspace_bytes` and `stream`,
/// with every pointer `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_segment_min_i64idx_f32_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `segment_min` — f64, i64 segment ids. Mirrors
/// `baracuda_kernels_segment_min_f64_run`; only the dereferenced type
/// of `segment_ids` differs. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_min_i64idx_f64_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `segment_min_i64idx_f64`. Takes the
/// `_run` arguments minus `workspace`, `workspace_bytes` and `stream`,
/// with every pointer `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_segment_min_i64idx_f64_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `segment_prod` — f32, i64 segment ids. Mirrors
/// `baracuda_kernels_segment_prod_f32_run`; only the dereferenced type
/// of `segment_ids` differs. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_prod_i64idx_f32_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `segment_prod_i64idx_f32`. Takes the
/// `_run` arguments minus `workspace`, `workspace_bytes` and `stream`,
/// with every pointer `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_segment_prod_i64idx_f32_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `segment_prod` — f64, i64 segment ids. Mirrors
/// `baracuda_kernels_segment_prod_f64_run`; only the dereferenced type
/// of `segment_ids` differs. Returns a crate-level status code
/// (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_prod_i64idx_f64_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `segment_prod_i64idx_f64`. Takes the
/// `_run` arguments minus `workspace`, `workspace_bytes` and `stream`,
/// with every pointer `*const`. Returns a crate-level status code
/// (`0` = success).
pub fn baracuda_kernels_segment_prod_i64idx_f64_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `unsorted_segment_sum` — f32, i64 segment ids. Mirrors
/// `baracuda_kernels_unsorted_segment_sum_f32_run`; only the
/// dereferenced type of `segment_ids` differs. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_sum_i64idx_f32_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `unsorted_segment_sum_i64idx_f32`.
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_sum_i64idx_f32_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `unsorted_segment_sum` — f64, i64 segment ids. Mirrors
/// `baracuda_kernels_unsorted_segment_sum_f64_run`; only the
/// dereferenced type of `segment_ids` differs. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_sum_i64idx_f64_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `unsorted_segment_sum_i64idx_f64`.
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_sum_i64idx_f64_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `unsorted_segment_mean` — f32, i64 segment ids. Mirrors
/// `baracuda_kernels_unsorted_segment_mean_f32_run`; only the
/// dereferenced type of `segment_ids` differs.
/// NOTE(review): presumably needs the same `num_segments * sizeof(i32)`
/// count workspace as the i32 entry point — confirm.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_mean_i64idx_f32_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `unsorted_segment_mean_i64idx_f32`.
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_mean_i64idx_f32_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `unsorted_segment_mean` — f64, i64 segment ids. Mirrors
/// `baracuda_kernels_unsorted_segment_mean_f64_run`; only the
/// dereferenced type of `segment_ids` differs.
/// NOTE(review): presumably needs the same `num_segments * sizeof(i32)`
/// count workspace as the i32 entry point — confirm.
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_mean_i64idx_f64_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `unsorted_segment_mean_i64idx_f64`.
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_mean_i64idx_f64_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `unsorted_segment_max` — f32, i64 segment ids. Mirrors
/// `baracuda_kernels_unsorted_segment_max_f32_run`; only the
/// dereferenced type of `segment_ids` differs. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_max_i64idx_f32_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `unsorted_segment_max_i64idx_f32`.
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_max_i64idx_f32_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `unsorted_segment_max` — f64, i64 segment ids. Mirrors
/// `baracuda_kernels_unsorted_segment_max_f64_run`; only the
/// dereferenced type of `segment_ids` differs. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_max_i64idx_f64_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `unsorted_segment_max_i64idx_f64`.
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_max_i64idx_f64_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `unsorted_segment_min` — f32, i64 segment ids. Mirrors
/// `baracuda_kernels_unsorted_segment_min_f32_run`; only the
/// dereferenced type of `segment_ids` differs. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_min_i64idx_f32_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `unsorted_segment_min_i64idx_f32`.
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_min_i64idx_f32_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `unsorted_segment_min` — f64, i64 segment ids. Mirrors
/// `baracuda_kernels_unsorted_segment_min_f64_run`; only the
/// dereferenced type of `segment_ids` differs. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_unsorted_segment_min_i64idx_f64_run(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `unsorted_segment_min_i64idx_f64`.
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_min_i64idx_f64_can_implement(n: i32, d: i32, num_segments: i32, input: *const c_void, segment_ids: *const c_void, output: *const c_void) -> i32;
/// `segment_sum_backward` — f32, i64 segment ids. Mirrors
/// `baracuda_kernels_segment_sum_backward_f32_run`; only the
/// dereferenced type of `segment_ids` differs. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_sum_backward_i64idx_f32_run(n: i32, d: i32, num_segments: i32, d_output: *const c_void, segment_ids: *const c_void, d_input: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `segment_sum_backward_i64idx_f32`.
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`, with every pointer `*const`. Returns a crate-level status
/// code (`0` = success).
pub fn baracuda_kernels_segment_sum_backward_i64idx_f32_can_implement(n: i32, d: i32, num_segments: i32, d_output: *const c_void, segment_ids: *const c_void, d_input: *const c_void) -> i32;
/// `segment_sum_backward` — f64, i64 segment ids. Mirrors
/// `baracuda_kernels_segment_sum_backward_f64_run`; only the
/// dereferenced type of `segment_ids` differs. Returns a crate-level
/// status code (`0` = success).
///
/// # Safety
/// Pointers must be valid device addresses and `stream` a valid CUDA
/// stream, as described in the crate-level docs.
pub fn baracuda_kernels_segment_sum_backward_i64idx_f64_run(n: i32, d: i32, num_segments: i32, d_output: *const c_void, segment_ids: *const c_void, d_input: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// `baracuda_kernels_segment_sum_backward_i64idx_f64_can_implement` (baracuda kernels segment sum backward i64idx f64 can implement).
pub fn baracuda_kernels_segment_sum_backward_i64idx_f64_can_implement(n: i32, d: i32, num_segments: i32, d_output: *const c_void, segment_ids: *const c_void, d_input: *const c_void) -> i32;
/// `unsorted_segment_sum_backward` run entry point — f32, `i64idx` variant.
///
/// `d_output` and `segment_ids` are `*const` operands; `d_input` is the
/// `*mut` destination. Returns a crate status code (`0` = success).
/// NOTE(review): `i64idx` presumably means 64-bit `segment_ids` — confirm.
pub fn baracuda_kernels_unsorted_segment_sum_backward_i64idx_f32_run(n: i32, d: i32, num_segments: i32, d_output: *const c_void, segment_ids: *const c_void, d_input: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `unsorted_segment_sum_backward_i64idx_f32`.
///
/// Same sizes and operand pointers as `_run`, without workspace or
/// stream. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_sum_backward_i64idx_f32_can_implement(n: i32, d: i32, num_segments: i32, d_output: *const c_void, segment_ids: *const c_void, d_input: *const c_void) -> i32;
/// `unsorted_segment_sum_backward` run entry point — f64, `i64idx` variant.
///
/// `d_output` and `segment_ids` are `*const` operands; `d_input` is the
/// `*mut` destination. Returns a crate status code (`0` = success).
/// NOTE(review): `i64idx` presumably means 64-bit `segment_ids` — confirm.
pub fn baracuda_kernels_unsorted_segment_sum_backward_i64idx_f64_run(n: i32, d: i32, num_segments: i32, d_output: *const c_void, segment_ids: *const c_void, d_input: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `unsorted_segment_sum_backward_i64idx_f64`.
///
/// Same sizes and operand pointers as `_run`, without workspace or
/// stream. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_sum_backward_i64idx_f64_can_implement(n: i32, d: i32, num_segments: i32, d_output: *const c_void, segment_ids: *const c_void, d_input: *const c_void) -> i32;
/// `segment_mean_backward` run entry point — f32, `i64idx` variant.
///
/// `d_output` and `segment_ids` are `*const` operands; `d_input` is the
/// `*mut` destination. Returns a crate status code (`0` = success).
/// NOTE(review): `i64idx` presumably means 64-bit `segment_ids` — confirm.
pub fn baracuda_kernels_segment_mean_backward_i64idx_f32_run(n: i32, d: i32, num_segments: i32, d_output: *const c_void, segment_ids: *const c_void, d_input: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `segment_mean_backward_i64idx_f32`.
///
/// Same sizes and operand pointers as `_run`, without workspace or
/// stream. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_segment_mean_backward_i64idx_f32_can_implement(n: i32, d: i32, num_segments: i32, d_output: *const c_void, segment_ids: *const c_void, d_input: *const c_void) -> i32;
/// `segment_mean_backward` run entry point — f64, `i64idx` variant.
///
/// `d_output` and `segment_ids` are `*const` operands; `d_input` is the
/// `*mut` destination. Returns a crate status code (`0` = success).
/// NOTE(review): `i64idx` presumably means 64-bit `segment_ids` — confirm.
pub fn baracuda_kernels_segment_mean_backward_i64idx_f64_run(n: i32, d: i32, num_segments: i32, d_output: *const c_void, segment_ids: *const c_void, d_input: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `segment_mean_backward_i64idx_f64`.
///
/// Same sizes and operand pointers as `_run`, without workspace or
/// stream. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_segment_mean_backward_i64idx_f64_can_implement(n: i32, d: i32, num_segments: i32, d_output: *const c_void, segment_ids: *const c_void, d_input: *const c_void) -> i32;
/// `unsorted_segment_mean_backward` run entry point — f32, `i64idx` variant.
///
/// `d_output` and `segment_ids` are `*const` operands; `d_input` is the
/// `*mut` destination. Returns a crate status code (`0` = success).
/// NOTE(review): `i64idx` presumably means 64-bit `segment_ids` — confirm.
pub fn baracuda_kernels_unsorted_segment_mean_backward_i64idx_f32_run(n: i32, d: i32, num_segments: i32, d_output: *const c_void, segment_ids: *const c_void, d_input: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `unsorted_segment_mean_backward_i64idx_f32`.
///
/// Same sizes and operand pointers as `_run`, without workspace or
/// stream. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_mean_backward_i64idx_f32_can_implement(n: i32, d: i32, num_segments: i32, d_output: *const c_void, segment_ids: *const c_void, d_input: *const c_void) -> i32;
/// `unsorted_segment_mean_backward` run entry point — f64, `i64idx` variant.
///
/// `d_output` and `segment_ids` are `*const` operands; `d_input` is the
/// `*mut` destination. Returns a crate status code (`0` = success).
/// NOTE(review): `i64idx` presumably means 64-bit `segment_ids` — confirm.
pub fn baracuda_kernels_unsorted_segment_mean_backward_i64idx_f64_run(n: i32, d: i32, num_segments: i32, d_output: *const c_void, segment_ids: *const c_void, d_input: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `unsorted_segment_mean_backward_i64idx_f64`.
///
/// Same sizes and operand pointers as `_run`, without workspace or
/// stream. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_mean_backward_i64idx_f64_can_implement(n: i32, d: i32, num_segments: i32, d_output: *const c_void, segment_ids: *const c_void, d_input: *const c_void) -> i32;
// ---------- Phase 25: Max / Min BW (sorted + unsorted) ----------
//
// Argmax / argmin recomputed in BW — preserves FW API (no paired-
// index tensor). Signature: extra `input` pointer for the rescan.
// Tie-break: first occurrence (lowest k). PyTorch chooses last.
/// `d_input[k, d] = d_output[seg, d]` iff k is the (first)
/// max-argument of the segment in column d, else 0. Sorted seg ids.
/// f32.
///
/// `input` is the forward input, rescanned here to recompute the
/// arg-max (no index tensor is saved by the forward pass). Ties go to
/// the lowest `k`. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_segment_max_backward_f32_run(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_max_backward_f32`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_segment_max_backward_f32_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
segment_ids: *const c_void,
d_input: *const c_void,
) -> i32;
/// `segment_max_backward` — f64.
///
/// Routes `d_output[seg, d]` to the first arg-max row of each segment
/// (recomputed from `input`); every other row of `d_input` gets 0.
/// Sorted segment ids. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_segment_max_backward_f64_run(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_max_backward_f64`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_segment_max_backward_f64_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
segment_ids: *const c_void,
d_input: *const c_void,
) -> i32;
/// `segment_min_backward` — f32.
///
/// Routes `d_output[seg, d]` to the first arg-min row of each segment
/// (recomputed from `input`); every other row of `d_input` gets 0.
/// Sorted segment ids. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_segment_min_backward_f32_run(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_min_backward_f32`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_segment_min_backward_f32_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
segment_ids: *const c_void,
d_input: *const c_void,
) -> i32;
/// `segment_min_backward` — f64.
///
/// Routes `d_output[seg, d]` to the first arg-min row of each segment
/// (recomputed from `input`); every other row of `d_input` gets 0.
/// Sorted segment ids. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_segment_min_backward_f64_run(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_min_backward_f64`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_segment_min_backward_f64_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
segment_ids: *const c_void,
d_input: *const c_void,
) -> i32;
/// `unsorted_segment_max_backward` — f32.
///
/// Unsorted-segment-id counterpart of `segment_max_backward`: the
/// arg-max is recomputed from `input`, ties go to the lowest `k`.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_max_backward_f32_run(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_max_backward_f32`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_max_backward_f32_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
segment_ids: *const c_void,
d_input: *const c_void,
) -> i32;
/// `unsorted_segment_max_backward` — f64.
///
/// Unsorted-segment-id counterpart of `segment_max_backward`: the
/// arg-max is recomputed from `input`, ties go to the lowest `k`.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_max_backward_f64_run(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_max_backward_f64`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_max_backward_f64_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
segment_ids: *const c_void,
d_input: *const c_void,
) -> i32;
/// `unsorted_segment_min_backward` — f32.
///
/// Unsorted-segment-id counterpart of `segment_min_backward`: the
/// arg-min is recomputed from `input`, ties go to the lowest `k`.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_min_backward_f32_run(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_min_backward_f32`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_min_backward_f32_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
segment_ids: *const c_void,
d_input: *const c_void,
) -> i32;
/// `unsorted_segment_min_backward` — f64.
///
/// Unsorted-segment-id counterpart of `segment_min_backward`: the
/// arg-min is recomputed from `input`, ties go to the lowest `k`.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_min_backward_f64_run(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_min_backward_f64`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_min_backward_f64_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
segment_ids: *const c_void,
d_input: *const c_void,
) -> i32;
// ---------- Phase 25: Prod BW (sorted + unsorted share the kernel) ----------
//
// `d_input[k, d] = d_output[seg, d] * (output[seg, d] / input[k, d])`.
// Direct division — caller must avoid zero-valued inputs or accept
// NaN / Inf in the gradient. Extra inputs: `input` and `output`
// (the saved FW `prod`).
/// `segment_prod_backward` — f32.
///
/// `d_input[k, d] = d_output[seg, d] * (output[seg, d] / input[k, d])`,
/// where `output` is the saved forward product. The division is direct:
/// zero-valued inputs yield NaN / Inf gradients. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_segment_prod_backward_f32_run(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
output: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_prod_backward_f32`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_segment_prod_backward_f32_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
output: *const c_void,
segment_ids: *const c_void,
d_input: *const c_void,
) -> i32;
/// `segment_prod_backward` — f64.
///
/// `d_input[k, d] = d_output[seg, d] * (output[seg, d] / input[k, d])`,
/// where `output` is the saved forward product. The division is direct:
/// zero-valued inputs yield NaN / Inf gradients. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_segment_prod_backward_f64_run(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
output: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `segment_prod_backward_f64`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_segment_prod_backward_f64_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
output: *const c_void,
segment_ids: *const c_void,
d_input: *const c_void,
) -> i32;
/// `unsorted_segment_prod_backward` — f32. Shares the kernel with
/// the sorted variant; distinct symbol for SKU tagging.
///
/// `d_input[k, d] = d_output[seg, d] * (output[seg, d] / input[k, d])`;
/// zero-valued inputs yield NaN / Inf gradients. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_prod_backward_f32_run(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
output: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_prod_backward_f32`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_prod_backward_f32_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
output: *const c_void,
segment_ids: *const c_void,
d_input: *const c_void,
) -> i32;
/// `unsorted_segment_prod_backward` — f64.
///
/// `d_input[k, d] = d_output[seg, d] * (output[seg, d] / input[k, d])`;
/// zero-valued inputs yield NaN / Inf gradients. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_prod_backward_f64_run(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
output: *const c_void,
segment_ids: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_prod_backward_f64`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_prod_backward_f64_can_implement(
n: i32, d: i32, num_segments: i32,
d_output: *const c_void,
input: *const c_void,
output: *const c_void,
segment_ids: *const c_void,
d_input: *const c_void,
) -> i32;
// ---------- Phase 25: Unsorted Prod FW (atomicCAS retry loop) ----------
//
// No native FP `atomicMul` — we do `atomicCAS` on the underlying
// 32 / 64-bit slot. Non-deterministic. Output pre-initialized to
// 1.0 by the launcher.
/// `unsorted_segment_prod` FW — f32.
///
/// Accumulates via an `atomicCAS` retry loop, so results are
/// non-deterministic. The launcher pre-initializes `output` to 1.0.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_prod_f32_run(
n: i32, d: i32, num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_prod_f32`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_prod_f32_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *const c_void,
) -> i32;
/// `unsorted_segment_prod` FW — f64.
///
/// Accumulates via an `atomicCAS` retry loop, so results are
/// non-deterministic. The launcher pre-initializes `output` to 1.0.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_prod_f64_run(
n: i32, d: i32, num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `unsorted_segment_prod_f64`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_unsorted_segment_prod_f64_can_implement(
n: i32, d: i32, num_segments: i32,
input: *const c_void,
segment_ids: *const c_void,
output: *const c_void,
) -> i32;
}
// ============================================================================
// Phase 8 Milestone 8.1 — per-tensor + per-channel quantize / dequantize
// + fake_quantize (Category P).
// ============================================================================
//
// FFI shape:
// per-tensor FW : (numel, scale, zero_point, qmin, qmax, x, q, ws, wsb, stream)
// per-tensor BW : (numel, scale, zero_point, qmin, qmax, x, dy, dx, ws, wsb, stream)
// dequant FW : (numel, scale, zero_point, q, x, ws, wsb, stream)
// dequant BW : (numel, scale, dy, dq, ws, wsb, stream)
// per-channel FW : (numel, shape4, axis, qmin, qmax, x, scale, zp, q, ws, wsb, stream)
// per-channel BW : (numel, shape4, axis, qmin, qmax, x, scale, zp, dy, dx, ws, wsb, stream)
// pc dequant FW : (numel, shape4, axis, q, scale, zp, x, ws, wsb, stream)
// pc dequant BW : (numel, shape4, axis, scale, dy, dq, ws, wsb, stream)
// fake_quantize FW : (numel, scale, zero_point, qmin, qmax, x, y, ws, wsb, stream)
// fake_quantize BW : (numel, scale, zero_point, qmin, qmax, x, dy, dx, ws, wsb, stream)
//
// The `scale` parameter is `f32` for the f32 / f16 / bf16 SKUs and `f64`
// for the f64 SKUs. The `f64` dtype token in the symbol name (e.g.
// `_f64_s8_run`, `_backward_f64_run`) marks the f64-scale flavor.
// `zero_point` is always i32. Per-channel kernels receive scale[] / zp[]
// as device pointers of length C (the extent along `axis`).
//
// STE BW convention: the "in-range mask" is recomputed in BW from the
// saved input `x` plus scale/zp — no separate mask tensor.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---------- quantize_per_tensor FW (8 SKUs) ----------
/// `q = clamp(round(x/scale)+zp, qmin, qmax)`. f32 input, s8 output.
///
/// `x` is the `*const` source, `q` the `*mut` destination; `scale`,
/// `zero_point`, `q_min` and `q_max` are passed by value. Returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_f32_s8_run(
numel: i64,
scale: f32,
zero_point: i32,
q_min: i32,
q_max: i32,
x: *const c_void,
q: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_tensor_f32_s8`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_f32_s8_can_implement(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, q: *const c_void,
) -> i32;
/// `quantize_per_tensor` — f32 → u8.
///
/// `q = clamp(round(x/scale)+zp, qmin, qmax)`. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_f32_u8_run(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, q: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_tensor_f32_u8`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_f32_u8_can_implement(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, q: *const c_void,
) -> i32;
/// `quantize_per_tensor` — f16 → s8.
///
/// `q = clamp(round(x/scale)+zp, qmin, qmax)` with an `f32` scale.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_f16_s8_run(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, q: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_tensor_f16_s8`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_f16_s8_can_implement(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, q: *const c_void,
) -> i32;
/// `quantize_per_tensor` — f16 → u8.
///
/// `q = clamp(round(x/scale)+zp, qmin, qmax)` with an `f32` scale.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_f16_u8_run(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, q: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_tensor_f16_u8`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_f16_u8_can_implement(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, q: *const c_void,
) -> i32;
/// `quantize_per_tensor` — bf16 → s8.
///
/// `q = clamp(round(x/scale)+zp, qmin, qmax)` with an `f32` scale.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_bf16_s8_run(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, q: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_tensor_bf16_s8`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_bf16_s8_can_implement(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, q: *const c_void,
) -> i32;
/// `quantize_per_tensor` — bf16 → u8.
///
/// `q = clamp(round(x/scale)+zp, qmin, qmax)` with an `f32` scale.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_bf16_u8_run(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, q: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_tensor_bf16_u8`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_bf16_u8_can_implement(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, q: *const c_void,
) -> i32;
/// `quantize_per_tensor` — f64 → s8 (f64 scale).
///
/// `q = clamp(round(x/scale)+zp, qmin, qmax)`. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_f64_s8_run(
numel: i64, scale: f64, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, q: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_tensor_f64_s8`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_f64_s8_can_implement(
numel: i64, scale: f64, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, q: *const c_void,
) -> i32;
/// `quantize_per_tensor` — f64 → u8 (f64 scale).
///
/// `q = clamp(round(x/scale)+zp, qmin, qmax)`. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_f64_u8_run(
numel: i64, scale: f64, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, q: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_tensor_f64_u8`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_f64_u8_can_implement(
numel: i64, scale: f64, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, q: *const c_void,
) -> i32;
// ---------- quantize_per_tensor BW (STE; 4 SKUs, FP-typed) ----------
/// `dx = (dy / scale) * in_range_mask(x)`. f32.
///
/// The in-range mask is recomputed from the saved input `x` plus
/// `scale` / `zero_point`; no mask tensor is passed. Returns a crate
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_backward_f32_run(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_tensor_backward_f32`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_backward_f32_can_implement(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// `quantize_per_tensor_backward` — f16.
///
/// `dx = (dy / scale) * in_range_mask(x)`, mask recomputed from `x`.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_backward_f16_run(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_tensor_backward_f16`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_backward_f16_can_implement(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// `quantize_per_tensor_backward` — bf16.
///
/// `dx = (dy / scale) * in_range_mask(x)`, mask recomputed from `x`.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_backward_bf16_run(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_tensor_backward_bf16`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_backward_bf16_can_implement(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// `quantize_per_tensor_backward` — f64 (f64 scale).
///
/// `dx = (dy / scale) * in_range_mask(x)`, mask recomputed from `x`.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_backward_f64_run(
numel: i64, scale: f64, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_tensor_backward_f64`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_quantize_per_tensor_backward_f64_can_implement(
numel: i64, scale: f64, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
// ---------- dequantize_per_tensor FW (8 SKUs) ----------
/// `x = scale * (q - zp)`. s8 → f32.
///
/// `q` is the `*const` quantized source, `x` the `*mut` destination.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_f32_s8_run(
numel: i64, scale: f32, zero_point: i32,
q: *const c_void, x: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_tensor_f32_s8`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_f32_s8_can_implement(
numel: i64, scale: f32, zero_point: i32,
q: *const c_void, x: *const c_void,
) -> i32;
/// `dequantize_per_tensor` — u8 → f32.
///
/// `x = scale * (q - zp)`. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_f32_u8_run(
numel: i64, scale: f32, zero_point: i32,
q: *const c_void, x: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_tensor_f32_u8`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_f32_u8_can_implement(
numel: i64, scale: f32, zero_point: i32,
q: *const c_void, x: *const c_void,
) -> i32;
/// `dequantize_per_tensor` — s8 → f16.
///
/// `x = scale * (q - zp)` with an `f32` scale. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_f16_s8_run(
numel: i64, scale: f32, zero_point: i32,
q: *const c_void, x: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_tensor_f16_s8`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_f16_s8_can_implement(
numel: i64, scale: f32, zero_point: i32,
q: *const c_void, x: *const c_void,
) -> i32;
/// `dequantize_per_tensor` — u8 → f16.
///
/// `x = scale * (q - zp)` with an `f32` scale. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_f16_u8_run(
numel: i64, scale: f32, zero_point: i32,
q: *const c_void, x: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_tensor_f16_u8`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_f16_u8_can_implement(
numel: i64, scale: f32, zero_point: i32,
q: *const c_void, x: *const c_void,
) -> i32;
/// `dequantize_per_tensor` — s8 → bf16.
///
/// `x = scale * (q - zp)` with an `f32` scale. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_bf16_s8_run(
numel: i64, scale: f32, zero_point: i32,
q: *const c_void, x: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_tensor_bf16_s8`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_bf16_s8_can_implement(
numel: i64, scale: f32, zero_point: i32,
q: *const c_void, x: *const c_void,
) -> i32;
/// `dequantize_per_tensor` — u8 → bf16.
///
/// `x = scale * (q - zp)` with an `f32` scale. Returns a crate status
/// code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_bf16_u8_run(
numel: i64, scale: f32, zero_point: i32,
q: *const c_void, x: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_tensor_bf16_u8`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_bf16_u8_can_implement(
numel: i64, scale: f32, zero_point: i32,
q: *const c_void, x: *const c_void,
) -> i32;
/// `dequantize_per_tensor` — s8 → f64 (f64 scale).
///
/// `x = scale * (q - zp)`. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_f64_s8_run(
numel: i64, scale: f64, zero_point: i32,
q: *const c_void, x: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_tensor_f64_s8`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_f64_s8_can_implement(
numel: i64, scale: f64, zero_point: i32,
q: *const c_void, x: *const c_void,
) -> i32;
/// `dequantize_per_tensor` — u8 → f64 (f64 scale).
///
/// `x = scale * (q - zp)`. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_f64_u8_run(
numel: i64, scale: f64, zero_point: i32,
q: *const c_void, x: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_tensor_f64_u8`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_f64_u8_can_implement(
numel: i64, scale: f64, zero_point: i32,
q: *const c_void, x: *const c_void,
) -> i32;
// ---------- dequantize_per_tensor BW (4 SKUs, FP-typed) ----------
/// `dq = dy * scale`. f32.
///
/// Takes no `zero_point`: only `scale` enters the gradient. Returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_backward_f32_run(
numel: i64, scale: f32,
dy: *const c_void, dq: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_tensor_backward_f32`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_backward_f32_can_implement(
numel: i64, scale: f32,
dy: *const c_void, dq: *const c_void,
) -> i32;
/// `dequantize_per_tensor_backward` — f16.
///
/// `dq = dy * scale` with an `f32` scale. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_backward_f16_run(
numel: i64, scale: f32,
dy: *const c_void, dq: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_tensor_backward_f16`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_backward_f16_can_implement(
numel: i64, scale: f32,
dy: *const c_void, dq: *const c_void,
) -> i32;
/// `dequantize_per_tensor_backward` — bf16.
///
/// `dq = dy * scale` with an `f32` scale. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_backward_bf16_run(
numel: i64, scale: f32,
dy: *const c_void, dq: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_tensor_backward_bf16`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_backward_bf16_can_implement(
numel: i64, scale: f32,
dy: *const c_void, dq: *const c_void,
) -> i32;
/// `dequantize_per_tensor_backward` — f64 (f64 scale).
///
/// `dq = dy * scale`. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_backward_f64_run(
numel: i64, scale: f64,
dy: *const c_void, dq: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_tensor_backward_f64`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_tensor_backward_f64_can_implement(
numel: i64, scale: f64,
dy: *const c_void, dq: *const c_void,
) -> i32;
// ---------- fake_quantize FW + BW (4 + 4 SKUs) ----------
/// `y = scale * (clamp(round(x/scale)+zp, qmin, qmax) - zp)`. f32.
///
/// Input `x` and output `y` share the floating-point dtype; no integer
/// tensor is produced. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_fake_quantize_f32_run(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `fake_quantize_f32`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_fake_quantize_f32_can_implement(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// `fake_quantize` — f16.
///
/// `y = scale * (clamp(round(x/scale)+zp, qmin, qmax) - zp)` with an
/// `f32` scale. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_fake_quantize_f16_run(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `fake_quantize_f16`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_fake_quantize_f16_can_implement(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// `fake_quantize` — bf16.
///
/// `y = scale * (clamp(round(x/scale)+zp, qmin, qmax) - zp)` with an
/// `f32` scale. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_fake_quantize_bf16_run(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `fake_quantize_bf16`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_fake_quantize_bf16_can_implement(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// `fake_quantize` — f64 (f64 scale).
///
/// `y = scale * (clamp(round(x/scale)+zp, qmin, qmax) - zp)`. Returns
/// a crate status code (`0` = success).
pub fn baracuda_kernels_fake_quantize_f64_run(
numel: i64, scale: f64, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `fake_quantize_f64`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_fake_quantize_f64_can_implement(
numel: i64, scale: f64, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, y: *const c_void,
) -> i32;
/// `dx = dy * in_range_mask(x)`. STE, no 1/scale factor. f32.
///
/// The in-range mask is recomputed from the saved input `x` plus
/// `scale` / `zero_point`; no mask tensor is passed. Returns a crate
/// status code (`0` = success).
pub fn baracuda_kernels_fake_quantize_backward_f32_run(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `fake_quantize_backward_f32`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_fake_quantize_backward_f32_can_implement(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// `fake_quantize_backward` — f16.
///
/// `dx = dy * in_range_mask(x)` (STE, no 1/scale factor), mask
/// recomputed from `x`. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_fake_quantize_backward_f16_run(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `fake_quantize_backward_f16`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_fake_quantize_backward_f16_can_implement(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// `fake_quantize_backward` — bf16.
///
/// `dx = dy * in_range_mask(x)` (STE, no 1/scale factor), mask
/// recomputed from `x`. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_fake_quantize_backward_bf16_run(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `fake_quantize_backward_bf16`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_fake_quantize_backward_bf16_can_implement(
numel: i64, scale: f32, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
/// `fake_quantize_backward` — f64 (f64 scale).
///
/// `dx = dy * in_range_mask(x)` (STE, no 1/scale factor), mask
/// recomputed from `x`. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_fake_quantize_backward_f64_run(
numel: i64, scale: f64, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `fake_quantize_backward_f64`.
///
/// Takes the `_run` arguments minus workspace and stream; returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_fake_quantize_backward_f64_can_implement(
numel: i64, scale: f64, zero_point: i32, q_min: i32, q_max: i32,
x: *const c_void, dy: *const c_void, dx: *const c_void,
) -> i32;
// ---------- quantize_per_channel FW (8 SKUs) ----------
//
// `shape4` is a 4-element i32 array (caller pads rank to 4 with 1's).
// `axis` selects which of the 4 dims indexes scale[] / zp[].
/// `q[i] = clamp(round(x[i]/scale[c])+zp[c], qmin, qmax)` where c =
/// coord[axis]. f32 → s8.
///
/// `shape4` is a 4-element `i32` array (rank padded to 4 with 1's) and
/// `axis` selects which of those dims indexes `scale`/`zero_point`.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_channel_f32_s8_run(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
q: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_channel_f32_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_channel_f32_s8_can_implement(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
q: *const c_void,
) -> i32;
/// `quantize_per_channel` — f32 → u8.
///
/// `shape4` is a 4-element `i32` array (rank padded to 4 with 1's) and
/// `axis` selects which of those dims indexes `scale`/`zero_point`.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_channel_f32_u8_run(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
q: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_channel_f32_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_channel_f32_u8_can_implement(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
q: *const c_void,
) -> i32;
/// `quantize_per_channel` — f16 → s8.
///
/// `shape4` is a 4-element `i32` array (rank padded to 4 with 1's) and
/// `axis` selects which of those dims indexes `scale`/`zero_point`.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_channel_f16_s8_run(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
q: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_channel_f16_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_channel_f16_s8_can_implement(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
q: *const c_void,
) -> i32;
/// `quantize_per_channel` — f16 → u8.
///
/// `shape4` is a 4-element `i32` array (rank padded to 4 with 1's) and
/// `axis` selects which of those dims indexes `scale`/`zero_point`.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_channel_f16_u8_run(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
q: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_channel_f16_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_channel_f16_u8_can_implement(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
q: *const c_void,
) -> i32;
/// `quantize_per_channel` — bf16 → s8.
///
/// `shape4` is a 4-element `i32` array (rank padded to 4 with 1's) and
/// `axis` selects which of those dims indexes `scale`/`zero_point`.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_channel_bf16_s8_run(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
q: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_channel_bf16_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_channel_bf16_s8_can_implement(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
q: *const c_void,
) -> i32;
/// `quantize_per_channel` — bf16 → u8.
///
/// `shape4` is a 4-element `i32` array (rank padded to 4 with 1's) and
/// `axis` selects which of those dims indexes `scale`/`zero_point`.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_channel_bf16_u8_run(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
q: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_channel_bf16_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_channel_bf16_u8_can_implement(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
q: *const c_void,
) -> i32;
/// `quantize_per_channel` — f64 → s8.
///
/// `shape4` is a 4-element `i32` array (rank padded to 4 with 1's) and
/// `axis` selects which of those dims indexes `scale`/`zero_point`.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_channel_f64_s8_run(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
q: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_channel_f64_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_channel_f64_s8_can_implement(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
q: *const c_void,
) -> i32;
/// `quantize_per_channel` — f64 → u8.
///
/// `shape4` is a 4-element `i32` array (rank padded to 4 with 1's) and
/// `axis` selects which of those dims indexes `scale`/`zero_point`.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_channel_f64_u8_run(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
q: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_channel_f64_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_channel_f64_u8_can_implement(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
q: *const c_void,
) -> i32;
// ---------- quantize_per_channel BW (4 SKUs, STE) ----------
/// `dx[i] = (dy[i] / scale[c]) * in_range_mask(x[i])`. f32.
///
/// `shape4`/`axis` presumably follow the per-channel forward convention
/// (4 `i32` dims, rank padded with 1's) — confirm.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_channel_backward_f32_run(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_channel_backward_f32`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_channel_backward_f32_can_implement(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
dy: *const c_void, dx: *const c_void,
) -> i32;
/// `quantize_per_channel_backward` — f16.
///
/// `shape4`/`axis` presumably follow the per-channel forward convention
/// (4 `i32` dims, rank padded with 1's) — confirm.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_channel_backward_f16_run(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_channel_backward_f16`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_channel_backward_f16_can_implement(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
dy: *const c_void, dx: *const c_void,
) -> i32;
/// `quantize_per_channel_backward` — bf16.
///
/// `shape4`/`axis` presumably follow the per-channel forward convention
/// (4 `i32` dims, rank padded with 1's) — confirm.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_channel_backward_bf16_run(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_channel_backward_bf16`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_channel_backward_bf16_can_implement(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
dy: *const c_void, dx: *const c_void,
) -> i32;
/// `quantize_per_channel_backward` — f64.
///
/// `shape4`/`axis` presumably follow the per-channel forward convention
/// (4 `i32` dims, rank padded with 1's) — confirm.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_channel_backward_f64_run(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
dy: *const c_void, dx: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_channel_backward_f64`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_channel_backward_f64_can_implement(
numel: i64, shape4: *const i32, axis: i32, q_min: i32, q_max: i32,
x: *const c_void, scale: *const c_void, zero_point: *const c_void,
dy: *const c_void, dx: *const c_void,
) -> i32;
// ---------- dequantize_per_channel FW (8 SKUs) ----------
/// `x[i] = scale[c] * (q[i] - zp[c])`. s8 → f32.
///
/// `shape4`/`axis` presumably follow the `quantize_per_channel` convention
/// (4 `i32` dims, rank padded with 1's) — confirm.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_channel_f32_s8_run(
numel: i64, shape4: *const i32, axis: i32,
q: *const c_void, scale: *const c_void, zero_point: *const c_void,
x: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_channel_f32_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_channel_f32_s8_can_implement(
numel: i64, shape4: *const i32, axis: i32,
q: *const c_void, scale: *const c_void, zero_point: *const c_void,
x: *const c_void,
) -> i32;
/// `dequantize_per_channel` — u8 → f32.
///
/// `shape4`/`axis` presumably follow the `quantize_per_channel` convention
/// (4 `i32` dims, rank padded with 1's) — confirm.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_channel_f32_u8_run(
numel: i64, shape4: *const i32, axis: i32,
q: *const c_void, scale: *const c_void, zero_point: *const c_void,
x: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_channel_f32_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_channel_f32_u8_can_implement(
numel: i64, shape4: *const i32, axis: i32,
q: *const c_void, scale: *const c_void, zero_point: *const c_void,
x: *const c_void,
) -> i32;
/// `dequantize_per_channel` — s8 → f16.
///
/// `shape4`/`axis` presumably follow the `quantize_per_channel` convention
/// (4 `i32` dims, rank padded with 1's) — confirm.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_channel_f16_s8_run(
numel: i64, shape4: *const i32, axis: i32,
q: *const c_void, scale: *const c_void, zero_point: *const c_void,
x: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_channel_f16_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_channel_f16_s8_can_implement(
numel: i64, shape4: *const i32, axis: i32,
q: *const c_void, scale: *const c_void, zero_point: *const c_void,
x: *const c_void,
) -> i32;
/// `dequantize_per_channel` — u8 → f16.
///
/// `shape4`/`axis` presumably follow the `quantize_per_channel` convention
/// (4 `i32` dims, rank padded with 1's) — confirm.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_channel_f16_u8_run(
numel: i64, shape4: *const i32, axis: i32,
q: *const c_void, scale: *const c_void, zero_point: *const c_void,
x: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_channel_f16_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_channel_f16_u8_can_implement(
numel: i64, shape4: *const i32, axis: i32,
q: *const c_void, scale: *const c_void, zero_point: *const c_void,
x: *const c_void,
) -> i32;
/// `dequantize_per_channel` — s8 → bf16.
///
/// `shape4`/`axis` presumably follow the `quantize_per_channel` convention
/// (4 `i32` dims, rank padded with 1's) — confirm.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_channel_bf16_s8_run(
numel: i64, shape4: *const i32, axis: i32,
q: *const c_void, scale: *const c_void, zero_point: *const c_void,
x: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_channel_bf16_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_channel_bf16_s8_can_implement(
numel: i64, shape4: *const i32, axis: i32,
q: *const c_void, scale: *const c_void, zero_point: *const c_void,
x: *const c_void,
) -> i32;
/// `dequantize_per_channel` — u8 → bf16.
///
/// `shape4`/`axis` presumably follow the `quantize_per_channel` convention
/// (4 `i32` dims, rank padded with 1's) — confirm.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_channel_bf16_u8_run(
numel: i64, shape4: *const i32, axis: i32,
q: *const c_void, scale: *const c_void, zero_point: *const c_void,
x: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_channel_bf16_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_channel_bf16_u8_can_implement(
numel: i64, shape4: *const i32, axis: i32,
q: *const c_void, scale: *const c_void, zero_point: *const c_void,
x: *const c_void,
) -> i32;
/// `dequantize_per_channel` — s8 → f64.
///
/// `shape4`/`axis` presumably follow the `quantize_per_channel` convention
/// (4 `i32` dims, rank padded with 1's) — confirm.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_channel_f64_s8_run(
numel: i64, shape4: *const i32, axis: i32,
q: *const c_void, scale: *const c_void, zero_point: *const c_void,
x: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_channel_f64_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_channel_f64_s8_can_implement(
numel: i64, shape4: *const i32, axis: i32,
q: *const c_void, scale: *const c_void, zero_point: *const c_void,
x: *const c_void,
) -> i32;
/// `dequantize_per_channel` — u8 → f64.
///
/// `shape4`/`axis` presumably follow the `quantize_per_channel` convention
/// (4 `i32` dims, rank padded with 1's) — confirm.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_channel_f64_u8_run(
numel: i64, shape4: *const i32, axis: i32,
q: *const c_void, scale: *const c_void, zero_point: *const c_void,
x: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_channel_f64_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_channel_f64_u8_can_implement(
numel: i64, shape4: *const i32, axis: i32,
q: *const c_void, scale: *const c_void, zero_point: *const c_void,
x: *const c_void,
) -> i32;
// ---------- dequantize_per_channel BW (4 SKUs) ----------
/// `dq[i] = dy[i] * scale[c]`. f32.
///
/// Takes no `zero_point` or quantized-input pointer — only `scale` and the
/// upstream gradient `dy`. `shape4`/`axis` presumably follow the
/// `quantize_per_channel` convention — confirm.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_channel_backward_f32_run(
numel: i64, shape4: *const i32, axis: i32,
scale: *const c_void, dy: *const c_void, dq: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_channel_backward_f32`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_channel_backward_f32_can_implement(
numel: i64, shape4: *const i32, axis: i32,
scale: *const c_void, dy: *const c_void, dq: *const c_void,
) -> i32;
/// `dequantize_per_channel_backward` — f16.
///
/// Takes no `zero_point` or quantized-input pointer — only `scale` and the
/// upstream gradient `dy`. `shape4`/`axis` presumably follow the
/// `quantize_per_channel` convention — confirm.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_channel_backward_f16_run(
numel: i64, shape4: *const i32, axis: i32,
scale: *const c_void, dy: *const c_void, dq: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_channel_backward_f16`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_channel_backward_f16_can_implement(
numel: i64, shape4: *const i32, axis: i32,
scale: *const c_void, dy: *const c_void, dq: *const c_void,
) -> i32;
/// `dequantize_per_channel_backward` — bf16.
///
/// Takes no `zero_point` or quantized-input pointer — only `scale` and the
/// upstream gradient `dy`. `shape4`/`axis` presumably follow the
/// `quantize_per_channel` convention — confirm.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_channel_backward_bf16_run(
numel: i64, shape4: *const i32, axis: i32,
scale: *const c_void, dy: *const c_void, dq: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_channel_backward_bf16`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_channel_backward_bf16_can_implement(
numel: i64, shape4: *const i32, axis: i32,
scale: *const c_void, dy: *const c_void, dq: *const c_void,
) -> i32;
/// `dequantize_per_channel_backward` — f64.
///
/// Takes no `zero_point` or quantized-input pointer — only `scale` and the
/// upstream gradient `dy`. `shape4`/`axis` presumably follow the
/// `quantize_per_channel` convention — confirm.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_channel_backward_f64_run(
numel: i64, shape4: *const i32, axis: i32,
scale: *const c_void, dy: *const c_void, dq: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_channel_backward_f64`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_channel_backward_f64_can_implement(
numel: i64, shape4: *const i32, axis: i32,
scale: *const c_void, dy: *const c_void, dq: *const c_void,
) -> i32;
}
// ============================================================================
// Phase 8 Milestone 8.2 — per-token + per-group quantize / dequantize
// (Category P, LLM/GPTQ-style)
// ============================================================================
//
// Sibling Milestone 8.1 (per-tensor / per-channel / fake_quantize) ships
// its own `unsafe extern "C" { ... }` block; this one carries only the
// per-token / per-group symbols so the two milestones never collide.
//
// FW signature (per-token):
// (n, d, qmin, qmax, input, scale, zero_point, output, ws, ws_bytes, stream)
// FW signature (per-group):
// (outer, axis_size, group_size, qmin, qmax,
// input, scale, zero_point, output, ws, ws_bytes, stream)
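// BW signature (per-token quantize, STE) adds `d_output` / `d_input` and
// keeps `input`:
// (n, d, qmin, qmax, d_output, input, scale, zero_point, d_input,
// ws, ws_bytes, stream)
// BW signature (per-token dequantize) drops `qmin/qmax` (straight-through,
// no clipping) and takes neither `input` nor `zero_point`:
// (n, d, d_output, scale, d_input, ws, ws_bytes, stream)
// Workspace is unused — these are pure pointwise kernels.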
unsafe extern "C" {
// ---------- quantize_per_token forward × 8 ----------
/// `quantize_per_token` — TIn f32, TOut s8. Status codes as elsewhere.
///
/// `n` and `d` are presumably the token count and per-token width — TODO
/// confirm. Per the section header, `workspace` is unused by these
/// pointwise kernels. `stream` is a `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_token_f32_s8_run(
n: i32,
d: i32,
qmin: i32,
qmax: i32,
input: *const c_void,
scale: *const c_void,
zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_token_f32_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_token_f32_s8_can_implement(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `quantize_per_token` — f32 → u8.
///
/// `n` and `d` are presumably the token count and per-token width — TODO
/// confirm. Per the section header, `workspace` is unused by these
/// pointwise kernels. Returns a crate-level status code (`0` = success);
/// `stream` is a `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_token_f32_u8_run(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_token_f32_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_token_f32_u8_can_implement(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `quantize_per_token` — f64 → s8.
///
/// `n` and `d` are presumably the token count and per-token width — TODO
/// confirm. Per the section header, `workspace` is unused by these
/// pointwise kernels. Returns a crate-level status code (`0` = success);
/// `stream` is a `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_token_f64_s8_run(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_token_f64_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_token_f64_s8_can_implement(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `quantize_per_token` — f64 → u8.
///
/// `n` and `d` are presumably the token count and per-token width — TODO
/// confirm. Per the section header, `workspace` is unused by these
/// pointwise kernels. Returns a crate-level status code (`0` = success);
/// `stream` is a `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_token_f64_u8_run(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_token_f64_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_token_f64_u8_can_implement(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `quantize_per_token` — f16 → s8.
///
/// `n` and `d` are presumably the token count and per-token width — TODO
/// confirm. Per the section header, `workspace` is unused by these
/// pointwise kernels. Returns a crate-level status code (`0` = success);
/// `stream` is a `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_token_f16_s8_run(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_token_f16_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_token_f16_s8_can_implement(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `quantize_per_token` — f16 → u8.
///
/// `n` and `d` are presumably the token count and per-token width — TODO
/// confirm. Per the section header, `workspace` is unused by these
/// pointwise kernels. Returns a crate-level status code (`0` = success);
/// `stream` is a `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_token_f16_u8_run(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_token_f16_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_token_f16_u8_can_implement(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `quantize_per_token` — bf16 → s8.
///
/// `n` and `d` are presumably the token count and per-token width — TODO
/// confirm. Per the section header, `workspace` is unused by these
/// pointwise kernels. Returns a crate-level status code (`0` = success);
/// `stream` is a `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_token_bf16_s8_run(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_token_bf16_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_token_bf16_s8_can_implement(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `quantize_per_token` — bf16 → u8.
///
/// `n` and `d` are presumably the token count and per-token width — TODO
/// confirm. Per the section header, `workspace` is unused by these
/// pointwise kernels. Returns a crate-level status code (`0` = success);
/// `stream` is a `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_token_bf16_u8_run(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_token_bf16_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_token_bf16_u8_can_implement(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
// ---------- quantize_per_token backward × 4 (STE) ----------
/// STE backward — f32.
///
/// Backward of `quantize_per_token`: reads the upstream gradient `d_output`
/// together with the forward operands (`input`, `scale`, `zero_point`) and
/// writes `d_input`. Per the section header, `workspace` is unused.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_token_backward_f32_run(
n: i32, d: i32, qmin: i32, qmax: i32,
d_output: *const c_void,
input: *const c_void,
scale: *const c_void,
zero_point: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_token_backward_f32`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_token_backward_f32_can_implement(
n: i32, d: i32, qmin: i32, qmax: i32,
d_output: *const c_void, input: *const c_void,
scale: *const c_void, zero_point: *const c_void,
d_input: *const c_void,
) -> i32;
/// STE backward — f64.
///
/// Backward of `quantize_per_token`: reads the upstream gradient `d_output`
/// together with the forward operands (`input`, `scale`, `zero_point`) and
/// writes `d_input`. Per the section header, `workspace` is unused.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_token_backward_f64_run(
n: i32, d: i32, qmin: i32, qmax: i32,
d_output: *const c_void, input: *const c_void,
scale: *const c_void, zero_point: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_token_backward_f64`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_token_backward_f64_can_implement(
n: i32, d: i32, qmin: i32, qmax: i32,
d_output: *const c_void, input: *const c_void,
scale: *const c_void, zero_point: *const c_void,
d_input: *const c_void,
) -> i32;
/// STE backward — f16.
///
/// Backward of `quantize_per_token`: reads the upstream gradient `d_output`
/// together with the forward operands (`input`, `scale`, `zero_point`) and
/// writes `d_input`. Per the section header, `workspace` is unused.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_token_backward_f16_run(
n: i32, d: i32, qmin: i32, qmax: i32,
d_output: *const c_void, input: *const c_void,
scale: *const c_void, zero_point: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_token_backward_f16`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_token_backward_f16_can_implement(
n: i32, d: i32, qmin: i32, qmax: i32,
d_output: *const c_void, input: *const c_void,
scale: *const c_void, zero_point: *const c_void,
d_input: *const c_void,
) -> i32;
/// STE backward — bf16.
///
/// Backward of `quantize_per_token`: reads the upstream gradient `d_output`
/// together with the forward operands (`input`, `scale`, `zero_point`) and
/// writes `d_input`. Per the section header, `workspace` is unused.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_token_backward_bf16_run(
n: i32, d: i32, qmin: i32, qmax: i32,
d_output: *const c_void, input: *const c_void,
scale: *const c_void, zero_point: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_token_backward_bf16`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_token_backward_bf16_can_implement(
n: i32, d: i32, qmin: i32, qmax: i32,
d_output: *const c_void, input: *const c_void,
scale: *const c_void, zero_point: *const c_void,
d_input: *const c_void,
) -> i32;
// ---------- dequantize_per_token forward × 8 ----------
/// `dequantize_per_token` — q s8 → y f32.
///
/// Has no `qmin`/`qmax` parameters, unlike the quantize direction. `n` and
/// `d` are presumably the token count and per-token width — TODO confirm.
/// Per the section header, `workspace` is unused. Returns a crate-level
/// status code (`0` = success); `stream` is a `cudaStream_t` cast to
/// `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_token_f32_s8_run(
n: i32, d: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_token_f32_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_token_f32_s8_can_implement(
n: i32, d: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `dequantize_per_token` — q u8 → y f32.
///
/// Has no `qmin`/`qmax` parameters, unlike the quantize direction. `n` and
/// `d` are presumably the token count and per-token width — TODO confirm.
/// Per the section header, `workspace` is unused. Returns a crate-level
/// status code (`0` = success); `stream` is a `cudaStream_t` cast to
/// `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_token_f32_u8_run(
n: i32, d: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_token_f32_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_token_f32_u8_can_implement(
n: i32, d: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `dequantize_per_token` — q s8 → y f64.
///
/// Has no `qmin`/`qmax` parameters, unlike the quantize direction. `n` and
/// `d` are presumably the token count and per-token width — TODO confirm.
/// Per the section header, `workspace` is unused. Returns a crate-level
/// status code (`0` = success); `stream` is a `cudaStream_t` cast to
/// `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_token_f64_s8_run(
n: i32, d: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_token_f64_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_token_f64_s8_can_implement(
n: i32, d: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `dequantize_per_token` — q u8 → y f64.
///
/// Has no `qmin`/`qmax` parameters, unlike the quantize direction. `n` and
/// `d` are presumably the token count and per-token width — TODO confirm.
/// Per the section header, `workspace` is unused. Returns a crate-level
/// status code (`0` = success); `stream` is a `cudaStream_t` cast to
/// `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_token_f64_u8_run(
n: i32, d: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_token_f64_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_token_f64_u8_can_implement(
n: i32, d: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `dequantize_per_token` — q s8 → y f16.
///
/// Has no `qmin`/`qmax` parameters, unlike the quantize direction. `n` and
/// `d` are presumably the token count and per-token width — TODO confirm.
/// Per the section header, `workspace` is unused. Returns a crate-level
/// status code (`0` = success); `stream` is a `cudaStream_t` cast to
/// `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_token_f16_s8_run(
n: i32, d: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_token_f16_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_token_f16_s8_can_implement(
n: i32, d: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `dequantize_per_token` — q u8 → y f16.
///
/// Has no `qmin`/`qmax` parameters, unlike the quantize direction. `n` and
/// `d` are presumably the token count and per-token width — TODO confirm.
/// Per the section header, `workspace` is unused. Returns a crate-level
/// status code (`0` = success); `stream` is a `cudaStream_t` cast to
/// `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_token_f16_u8_run(
n: i32, d: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_token_f16_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_token_f16_u8_can_implement(
n: i32, d: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `dequantize_per_token` — q s8 → y bf16.
///
/// Has no `qmin`/`qmax` parameters, unlike the quantize direction. `n` and
/// `d` are presumably the token count and per-token width — TODO confirm.
/// Per the section header, `workspace` is unused. Returns a crate-level
/// status code (`0` = success); `stream` is a `cudaStream_t` cast to
/// `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_token_bf16_s8_run(
n: i32, d: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_token_bf16_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_token_bf16_s8_can_implement(
n: i32, d: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `dequantize_per_token` — q u8 → y bf16.
///
/// Has no `qmin`/`qmax` parameters, unlike the quantize direction. `n` and
/// `d` are presumably the token count and per-token width — TODO confirm.
/// Per the section header, `workspace` is unused. Returns a crate-level
/// status code (`0` = success); `stream` is a `cudaStream_t` cast to
/// `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_token_bf16_u8_run(
n: i32, d: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_token_bf16_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_token_bf16_u8_can_implement(
n: i32, d: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
// ---------- dequantize_per_token backward × 4 (straight-through) ----------
/// Dequant BW — f32.
///
/// Backward of `dequantize_per_token`: takes only the upstream gradient
/// `d_output` and `scale` (no `zero_point`, no forward input) and writes
/// `d_input`. Per the section header, `workspace` is unused.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_token_backward_f32_run(
n: i32, d: i32,
d_output: *const c_void, scale: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_token_backward_f32`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_token_backward_f32_can_implement(
n: i32, d: i32,
d_output: *const c_void, scale: *const c_void,
d_input: *const c_void,
) -> i32;
/// Dequant BW — f64.
///
/// Backward of `dequantize_per_token`: takes only the upstream gradient
/// `d_output` and `scale` (no `zero_point`, no forward input) and writes
/// `d_input`. Per the section header, `workspace` is unused.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_token_backward_f64_run(
n: i32, d: i32,
d_output: *const c_void, scale: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_token_backward_f64`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_token_backward_f64_can_implement(
n: i32, d: i32,
d_output: *const c_void, scale: *const c_void,
d_input: *const c_void,
) -> i32;
/// Dequant BW — f16.
///
/// Backward of `dequantize_per_token`: takes only the upstream gradient
/// `d_output` and `scale` (no `zero_point`, no forward input) and writes
/// `d_input`. Per the section header, `workspace` is unused.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_token_backward_f16_run(
n: i32, d: i32,
d_output: *const c_void, scale: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_token_backward_f16`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_token_backward_f16_can_implement(
n: i32, d: i32,
d_output: *const c_void, scale: *const c_void,
d_input: *const c_void,
) -> i32;
/// Dequant BW — bf16.
///
/// Backward of `dequantize_per_token`: takes only the upstream gradient
/// `d_output` and `scale` (no `zero_point`, no forward input) and writes
/// `d_input`. Per the section header, `workspace` is unused.
/// Returns a crate-level status code (`0` = success); `stream` is a
/// `cudaStream_t` cast to `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_dequantize_per_token_backward_bf16_run(
n: i32, d: i32,
d_output: *const c_void, scale: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_token_backward_bf16`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_dequantize_per_token_backward_bf16_can_implement(
n: i32, d: i32,
d_output: *const c_void, scale: *const c_void,
d_input: *const c_void,
) -> i32;
// ---------- quantize_per_group forward × 8 ----------
/// `quantize_per_group` — f32 → s8.
///
/// Presumably `axis_size` is split into groups of `group_size` within each
/// of `outer` slices, one `scale`/`zero_point` per group — TODO confirm.
/// Per the section header, `workspace` is unused. Returns a crate-level
/// status code (`0` = success); `stream` is a `cudaStream_t` cast to
/// `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_group_f32_s8_run(
outer: i32, axis_size: i32, group_size: i32,
qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_group_f32_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_group_f32_s8_can_implement(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `quantize_per_group` — f32 → u8.
///
/// Presumably `axis_size` is split into groups of `group_size` within each
/// of `outer` slices, one `scale`/`zero_point` per group — TODO confirm.
/// Per the section header, `workspace` is unused. Returns a crate-level
/// status code (`0` = success); `stream` is a `cudaStream_t` cast to
/// `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_group_f32_u8_run(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_group_f32_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_group_f32_u8_can_implement(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `quantize_per_group` — f64 → s8.
///
/// Presumably `axis_size` is split into groups of `group_size` within each
/// of `outer` slices, one `scale`/`zero_point` per group — TODO confirm.
/// Per the section header, `workspace` is unused. Returns a crate-level
/// status code (`0` = success); `stream` is a `cudaStream_t` cast to
/// `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_group_f64_s8_run(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_group_f64_s8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_group_f64_s8_can_implement(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `quantize_per_group` — f64 → u8.
///
/// Presumably `axis_size` is split into groups of `group_size` within each
/// of `outer` slices, one `scale`/`zero_point` per group — TODO confirm.
/// Per the section header, `workspace` is unused. Returns a crate-level
/// status code (`0` = success); `stream` is a `cudaStream_t` cast to
/// `*mut c_void`.
///
/// # Safety
/// Operand pointers must satisfy the crate-level requirements; nothing is
/// bounds-checked.
pub fn baracuda_kernels_quantize_per_group_f64_u8_run(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_group_f64_u8`.
///
/// Takes only the problem description and operand pointers (all `*const`);
/// there are no workspace or stream parameters. Returns a crate-level
/// status code (`0` = success).
pub fn baracuda_kernels_quantize_per_group_f64_u8_can_implement(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `quantize_per_group` — f16 → s8.
///
/// Takes `input`, `scale` and `zero_point` as read-only (`*const`) pointers
/// and `output` as the writable (`*mut`) destination. Returns a crate-level
/// status code (`0` = success).
/// NOTE(review): `outer` / `axis_size` / `group_size` presumably give the
/// problem shape and group length, and `qmin` / `qmax` the quantized clamp
/// range — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_quantize_per_group_f16_s8_run(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_group_f16_s8`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_quantize_per_group_f16_s8_can_implement(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `quantize_per_group` — f16 → u8.
///
/// Takes `input`, `scale` and `zero_point` as read-only (`*const`) pointers
/// and `output` as the writable (`*mut`) destination. Returns a crate-level
/// status code (`0` = success).
/// NOTE(review): `outer` / `axis_size` / `group_size` presumably give the
/// problem shape and group length, and `qmin` / `qmax` the quantized clamp
/// range — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_quantize_per_group_f16_u8_run(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_group_f16_u8`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_quantize_per_group_f16_u8_can_implement(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `quantize_per_group` — bf16 → s8.
///
/// Takes `input`, `scale` and `zero_point` as read-only (`*const`) pointers
/// and `output` as the writable (`*mut`) destination. Returns a crate-level
/// status code (`0` = success).
/// NOTE(review): `outer` / `axis_size` / `group_size` presumably give the
/// problem shape and group length, and `qmin` / `qmax` the quantized clamp
/// range — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_quantize_per_group_bf16_s8_run(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_group_bf16_s8`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_quantize_per_group_bf16_s8_can_implement(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// `quantize_per_group` — bf16 → u8.
///
/// Takes `input`, `scale` and `zero_point` as read-only (`*const`) pointers
/// and `output` as the writable (`*mut`) destination. Returns a crate-level
/// status code (`0` = success).
/// NOTE(review): `outer` / `axis_size` / `group_size` presumably give the
/// problem shape and group length, and `qmin` / `qmax` the quantized clamp
/// range — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_quantize_per_group_bf16_u8_run(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_group_bf16_u8`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_quantize_per_group_bf16_u8_can_implement(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
// ---------- quantize_per_group backward × 4 (STE) ----------
/// STE BW — f32.
///
/// Backward entry point for `quantize_per_group`. `d_output`, `input`,
/// `scale` and `zero_point` are read-only (`*const`) pointers; `d_input` is
/// the writable (`*mut`) destination. Returns a crate-level status code
/// (`0` = success).
/// NOTE(review): presumably `d_output` is the incoming gradient and `d_input`
/// the produced gradient — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_quantize_per_group_backward_f32_run(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
d_output: *const c_void, input: *const c_void,
scale: *const c_void, zero_point: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_group_backward_f32`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_quantize_per_group_backward_f32_can_implement(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
d_output: *const c_void, input: *const c_void,
scale: *const c_void, zero_point: *const c_void,
d_input: *const c_void,
) -> i32;
/// STE BW — f64.
///
/// Backward entry point for `quantize_per_group`. `d_output`, `input`,
/// `scale` and `zero_point` are read-only (`*const`) pointers; `d_input` is
/// the writable (`*mut`) destination. Returns a crate-level status code
/// (`0` = success).
/// NOTE(review): presumably `d_output` is the incoming gradient and `d_input`
/// the produced gradient — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_quantize_per_group_backward_f64_run(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
d_output: *const c_void, input: *const c_void,
scale: *const c_void, zero_point: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_group_backward_f64`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_quantize_per_group_backward_f64_can_implement(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
d_output: *const c_void, input: *const c_void,
scale: *const c_void, zero_point: *const c_void,
d_input: *const c_void,
) -> i32;
/// STE BW — f16.
///
/// Backward entry point for `quantize_per_group`. `d_output`, `input`,
/// `scale` and `zero_point` are read-only (`*const`) pointers; `d_input` is
/// the writable (`*mut`) destination. Returns a crate-level status code
/// (`0` = success).
/// NOTE(review): presumably `d_output` is the incoming gradient and `d_input`
/// the produced gradient — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_quantize_per_group_backward_f16_run(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
d_output: *const c_void, input: *const c_void,
scale: *const c_void, zero_point: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_group_backward_f16`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_quantize_per_group_backward_f16_can_implement(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
d_output: *const c_void, input: *const c_void,
scale: *const c_void, zero_point: *const c_void,
d_input: *const c_void,
) -> i32;
/// STE BW — bf16.
///
/// Backward entry point for `quantize_per_group`. `d_output`, `input`,
/// `scale` and `zero_point` are read-only (`*const`) pointers; `d_input` is
/// the writable (`*mut`) destination. Returns a crate-level status code
/// (`0` = success).
/// NOTE(review): presumably `d_output` is the incoming gradient and `d_input`
/// the produced gradient — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_quantize_per_group_backward_bf16_run(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
d_output: *const c_void, input: *const c_void,
scale: *const c_void, zero_point: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantize_per_group_backward_bf16`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_quantize_per_group_backward_bf16_can_implement(
outer: i32, axis_size: i32, group_size: i32, qmin: i32, qmax: i32,
d_output: *const c_void, input: *const c_void,
scale: *const c_void, zero_point: *const c_void,
d_input: *const c_void,
) -> i32;
// ---------- dequantize_per_group forward × 8 ----------
/// Dequant — f32, s8.
///
/// `dequantize_per_group` forward entry point. `input`, `scale` and
/// `zero_point` are read-only (`*const`) pointers; `output` is the writable
/// (`*mut`) destination. Returns a crate-level status code (`0` = success).
/// NOTE(review): presumably f32 is the floating-point output type and s8 the
/// quantized input type — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_dequantize_per_group_f32_s8_run(
outer: i32, axis_size: i32, group_size: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_group_f32_s8`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_dequantize_per_group_f32_s8_can_implement(
outer: i32, axis_size: i32, group_size: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// Dequant — f32, u8.
///
/// `dequantize_per_group` forward entry point. `input`, `scale` and
/// `zero_point` are read-only (`*const`) pointers; `output` is the writable
/// (`*mut`) destination. Returns a crate-level status code (`0` = success).
/// NOTE(review): presumably f32 is the floating-point output type and u8 the
/// quantized input type — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_dequantize_per_group_f32_u8_run(
outer: i32, axis_size: i32, group_size: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_group_f32_u8`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_dequantize_per_group_f32_u8_can_implement(
outer: i32, axis_size: i32, group_size: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// Dequant — f64, s8.
///
/// `dequantize_per_group` forward entry point. `input`, `scale` and
/// `zero_point` are read-only (`*const`) pointers; `output` is the writable
/// (`*mut`) destination. Returns a crate-level status code (`0` = success).
/// NOTE(review): presumably f64 is the floating-point output type and s8 the
/// quantized input type — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_dequantize_per_group_f64_s8_run(
outer: i32, axis_size: i32, group_size: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_group_f64_s8`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_dequantize_per_group_f64_s8_can_implement(
outer: i32, axis_size: i32, group_size: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// Dequant — f64, u8.
///
/// `dequantize_per_group` forward entry point. `input`, `scale` and
/// `zero_point` are read-only (`*const`) pointers; `output` is the writable
/// (`*mut`) destination. Returns a crate-level status code (`0` = success).
/// NOTE(review): presumably f64 is the floating-point output type and u8 the
/// quantized input type — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_dequantize_per_group_f64_u8_run(
outer: i32, axis_size: i32, group_size: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_group_f64_u8`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_dequantize_per_group_f64_u8_can_implement(
outer: i32, axis_size: i32, group_size: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// Dequant — f16, s8.
///
/// `dequantize_per_group` forward entry point. `input`, `scale` and
/// `zero_point` are read-only (`*const`) pointers; `output` is the writable
/// (`*mut`) destination. Returns a crate-level status code (`0` = success).
/// NOTE(review): presumably f16 is the floating-point output type and s8 the
/// quantized input type — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_dequantize_per_group_f16_s8_run(
outer: i32, axis_size: i32, group_size: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_group_f16_s8`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_dequantize_per_group_f16_s8_can_implement(
outer: i32, axis_size: i32, group_size: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// Dequant — f16, u8.
///
/// `dequantize_per_group` forward entry point. `input`, `scale` and
/// `zero_point` are read-only (`*const`) pointers; `output` is the writable
/// (`*mut`) destination. Returns a crate-level status code (`0` = success).
/// NOTE(review): presumably f16 is the floating-point output type and u8 the
/// quantized input type — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_dequantize_per_group_f16_u8_run(
outer: i32, axis_size: i32, group_size: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_group_f16_u8`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_dequantize_per_group_f16_u8_can_implement(
outer: i32, axis_size: i32, group_size: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// Dequant — bf16, s8.
///
/// `dequantize_per_group` forward entry point. `input`, `scale` and
/// `zero_point` are read-only (`*const`) pointers; `output` is the writable
/// (`*mut`) destination. Returns a crate-level status code (`0` = success).
/// NOTE(review): presumably bf16 is the floating-point output type and s8 the
/// quantized input type — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_dequantize_per_group_bf16_s8_run(
outer: i32, axis_size: i32, group_size: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_group_bf16_s8`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_dequantize_per_group_bf16_s8_can_implement(
outer: i32, axis_size: i32, group_size: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
/// Dequant — bf16, u8.
///
/// `dequantize_per_group` forward entry point. `input`, `scale` and
/// `zero_point` are read-only (`*const`) pointers; `output` is the writable
/// (`*mut`) destination. Returns a crate-level status code (`0` = success).
/// NOTE(review): presumably bf16 is the floating-point output type and u8 the
/// quantized input type — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_dequantize_per_group_bf16_u8_run(
outer: i32, axis_size: i32, group_size: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_group_bf16_u8`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_dequantize_per_group_bf16_u8_can_implement(
outer: i32, axis_size: i32, group_size: i32,
input: *const c_void, scale: *const c_void, zero_point: *const c_void,
output: *const c_void,
) -> i32;
// ---------- dequantize_per_group backward × 4 (straight-through) ----------
/// Dequant BW — f32.
///
/// Backward entry point for `dequantize_per_group`. `d_output` and `scale`
/// are read-only (`*const`) pointers; `d_input` is the writable (`*mut`)
/// destination. No `zero_point` or forward `input` is taken. Returns a
/// crate-level status code (`0` = success).
/// NOTE(review): presumably `d_output` is the incoming gradient and `d_input`
/// the produced gradient — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_dequantize_per_group_backward_f32_run(
outer: i32, axis_size: i32, group_size: i32,
d_output: *const c_void, scale: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_group_backward_f32`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_dequantize_per_group_backward_f32_can_implement(
outer: i32, axis_size: i32, group_size: i32,
d_output: *const c_void, scale: *const c_void,
d_input: *const c_void,
) -> i32;
/// Dequant BW — f64.
///
/// Backward entry point for `dequantize_per_group`. `d_output` and `scale`
/// are read-only (`*const`) pointers; `d_input` is the writable (`*mut`)
/// destination. No `zero_point` or forward `input` is taken. Returns a
/// crate-level status code (`0` = success).
/// NOTE(review): presumably `d_output` is the incoming gradient and `d_input`
/// the produced gradient — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_dequantize_per_group_backward_f64_run(
outer: i32, axis_size: i32, group_size: i32,
d_output: *const c_void, scale: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_group_backward_f64`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_dequantize_per_group_backward_f64_can_implement(
outer: i32, axis_size: i32, group_size: i32,
d_output: *const c_void, scale: *const c_void,
d_input: *const c_void,
) -> i32;
/// Dequant BW — f16.
///
/// Backward entry point for `dequantize_per_group`. `d_output` and `scale`
/// are read-only (`*const`) pointers; `d_input` is the writable (`*mut`)
/// destination. No `zero_point` or forward `input` is taken. Returns a
/// crate-level status code (`0` = success).
/// NOTE(review): presumably `d_output` is the incoming gradient and `d_input`
/// the produced gradient — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_dequantize_per_group_backward_f16_run(
outer: i32, axis_size: i32, group_size: i32,
d_output: *const c_void, scale: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_group_backward_f16`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_dequantize_per_group_backward_f16_can_implement(
outer: i32, axis_size: i32, group_size: i32,
d_output: *const c_void, scale: *const c_void,
d_input: *const c_void,
) -> i32;
/// Dequant BW — bf16.
///
/// Backward entry point for `dequantize_per_group`. `d_output` and `scale`
/// are read-only (`*const`) pointers; `d_input` is the writable (`*mut`)
/// destination. No `zero_point` or forward `input` is taken. Returns a
/// crate-level status code (`0` = success).
/// NOTE(review): presumably `d_output` is the incoming gradient and `d_input`
/// the produced gradient — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_dequantize_per_group_backward_bf16_run(
outer: i32, axis_size: i32, group_size: i32,
d_output: *const c_void, scale: *const c_void,
d_input: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dequantize_per_group_backward_bf16`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_dequantize_per_group_backward_bf16_can_implement(
outer: i32, axis_size: i32, group_size: i32,
d_output: *const c_void, scale: *const c_void,
d_input: *const c_void,
) -> i32;
// ---------- Milestone 8.3 — composing quantization ops ----------
//
// Two op families:
// 1. `dynamic_range_quantize_per_token_sym` — fused per-row
// max-abs + symmetric scale + per-token quantize. Writes the
// computed `scale[N]` vector alongside the quantized output.
// 2. `quantized_linear_w8a8` — naive W8A8 quantized matmul.
// Input: already-quantized int8 activation `[M, K]` +
// already-quantized int8 weight `[C_out, K]` + per-row
// `scale_a[M]` + per-channel `scale_w[C_out]`. Output: FP
// `[M, C_out]`. Trailblazer = correctness scaffold, not
// perf-optimized.
//
// Trailblazer dtype coverage: TIn ∈ {f32, f64}, TOut weight/act = s8.
/// `dynamic_range_quantize_per_token_sym` — f32 → s8.
///
/// Fused per-row max-abs, symmetric scale computation and per-token
/// quantization. `input` is read-only; both `scale` and `output` are written:
/// the computed `scale[N]` vector is emitted alongside the quantized output.
/// Returns a crate-level status code (`0` = success).
/// NOTE(review): presumably `n` is the row (token) count `N` and `d` the row
/// length, with `qmin` / `qmax` the quantized clamp range — confirm against
/// the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_dynamic_range_quantize_per_token_sym_f32_s8_run(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void,
scale: *mut c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dynamic_range_quantize_per_token_sym_f32_s8`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream; `scale` and `output` are `*const` here) and
/// returns a crate-level status code.
pub fn baracuda_kernels_dynamic_range_quantize_per_token_sym_f32_s8_can_implement(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void,
scale: *const c_void,
output: *const c_void,
) -> i32;
/// `dynamic_range_quantize_per_token_sym` — f64 → s8.
///
/// Fused per-row max-abs, symmetric scale computation and per-token
/// quantization. `input` is read-only; both `scale` and `output` are written:
/// the computed `scale[N]` vector is emitted alongside the quantized output.
/// Returns a crate-level status code (`0` = success).
/// NOTE(review): presumably `n` is the row (token) count `N` and `d` the row
/// length, with `qmin` / `qmax` the quantized clamp range — confirm against
/// the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_dynamic_range_quantize_per_token_sym_f64_s8_run(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void,
scale: *mut c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `dynamic_range_quantize_per_token_sym_f64_s8`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream; `scale` and `output` are `*const` here) and
/// returns a crate-level status code.
pub fn baracuda_kernels_dynamic_range_quantize_per_token_sym_f64_s8_can_implement(
n: i32, d: i32, qmin: i32, qmax: i32,
input: *const c_void,
scale: *const c_void,
output: *const c_void,
) -> i32;
/// `quantized_linear_w8a8` — TIn = f32.
///
/// Naive W8A8 quantized matmul (correctness scaffold, not perf-optimized).
/// Consumes an already-quantized int8 weight `weight_q` `[C_out, K]`, an
/// already-quantized int8 activation `act_q` `[M, K]`, per-row `scale_a[M]`
/// and per-channel `scale_w[C_out]`; writes floating-point `output`
/// `[M, C_out]`. Note the argument order: `weight_q` precedes `act_q`.
/// Returns a crate-level status code (`0` = success).
/// NOTE(review): `TIn = f32` presumably names the floating-point type of the
/// scales and `output` — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_quantized_linear_w8a8_f32_run(
m: i32, c_out: i32, k: i32,
weight_q: *const c_void,
act_q: *const c_void,
scale_a: *const c_void,
scale_w: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantized_linear_w8a8_f32`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_quantized_linear_w8a8_f32_can_implement(
m: i32, c_out: i32, k: i32,
weight_q: *const c_void,
act_q: *const c_void,
scale_a: *const c_void,
scale_w: *const c_void,
output: *const c_void,
) -> i32;
/// `quantized_linear_w8a8` — TIn = f64.
///
/// Naive W8A8 quantized matmul (correctness scaffold, not perf-optimized).
/// Consumes an already-quantized int8 weight `weight_q` `[C_out, K]`, an
/// already-quantized int8 activation `act_q` `[M, K]`, per-row `scale_a[M]`
/// and per-channel `scale_w[C_out]`; writes floating-point `output`
/// `[M, C_out]`. Note the argument order: `weight_q` precedes `act_q`.
/// Returns a crate-level status code (`0` = success).
/// NOTE(review): `TIn = f64` presumably names the floating-point type of the
/// scales and `output` — confirm against the kernel source.
///
/// # Safety
/// Every pointer must be a valid device address; `workspace`, when non-null,
/// must cover `workspace_bytes` writable bytes; `stream` must be a valid CUDA
/// stream in the current context.
pub fn baracuda_kernels_quantized_linear_w8a8_f64_run(
m: i32, c_out: i32, k: i32,
weight_q: *const c_void,
act_q: *const c_void,
scale_a: *const c_void,
scale_w: *const c_void,
output: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Implementability check for `quantized_linear_w8a8_f64`.
///
/// Accepts the same dimensions and pointers as the `_run` entry point
/// (without workspace or stream) and returns a crate-level status code.
pub fn baracuda_kernels_quantized_linear_w8a8_f64_can_implement(
m: i32, c_out: i32, k: i32,
weight_q: *const c_void,
act_q: *const c_void,
scale_a: *const c_void,
scale_w: *const c_void,
output: *const c_void,
) -> i32;
}
// ============================================================================
// Cast / Fill / Affine — Phase 3 fanout from fuel-cuda-kernels
// ============================================================================
//
// Vendored / adapted from `fuel-cuda-kernels/src/{cast,fill,affine}.cu`.
// See `crates/baracuda-kernels-sys/LICENSE-thirdparty.md`. Contig-only
// fast path; baracuda's plan layer materializes strided views upstream.
//
// Status codes mirror the GEMM family (see crate-level doc).
// Cast — explicit per-pair declarations (no macro to keep no_std + no
// proc-macro deps; each pair is just two trivial `pub fn` lines).
//
// Safety contract shared by all cast `_run` / `_can_implement` pairs:
// - `x` and `y` must each point to at least `numel` elements of device
// memory in the input / output element type respectively.
// - `stream` must be a live CUDA stream in the current context.
// - `_can_implement` performs host-side checks only.
//
// **Aliasing (Phase 64)**: aliasing `y` with `x` is safe IF AND ONLY IF
// `sizeof(TIn) == sizeof(TOut)` (same byte width across the cast). The
// kernel body is `y[i] = cast(x[i])` — each thread reads its own input
// cell, then writes its own output cell. With matching byte widths,
// both reads and writes share an address (per i), so per-thread access
// is structurally safe regardless of `__restrict__`. With differing
// byte widths, the output element at index i lives at a byte offset
// `i * sizeof(TOut)` that overlaps OTHER threads' input cells at
// `j * sizeof(TIn)`, producing a data race. Same-width casts that
// are aliasing-safe today include f32↔i32, f32↔u32, i32↔u32, f16↔bf16,
// f64↔i64, f64↔u64, u8↔i8 (and identity casts cast_<T>_<T>). All
// other casts (e.g. f32→f64, f16→f32) are NOT aliasing-safe.
// This contract is stable across baracuda versions.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// f32 -> *
/// Cast `f32 -> f32`. See `LICENSE-thirdparty.md`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f32` in,
/// `f32` out); `stream` must be a live CUDA stream in the current context.
/// `y` may alias `x`: identity cast, same element byte width.
pub fn baracuda_kernels_cast_f32_f32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f32_f32`; host-side checks only.
pub fn baracuda_kernels_cast_f32_f32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f32 -> f64`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f32` in,
/// `f64` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f32_f64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f32_f64`; host-side checks only.
pub fn baracuda_kernels_cast_f32_f64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f32 -> f16`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f32` in,
/// `f16` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f32_f16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f32_f16`; host-side checks only.
pub fn baracuda_kernels_cast_f32_f16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f32 -> bf16`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f32` in,
/// `bf16` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f32_bf16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f32_bf16`; host-side checks only.
pub fn baracuda_kernels_cast_f32_bf16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f32 -> i32`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f32` in,
/// `i32` out); `stream` must be a live CUDA stream in the current context.
/// `y` may alias `x`: both element types have the same byte width.
pub fn baracuda_kernels_cast_f32_i32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f32_i32`; host-side checks only.
pub fn baracuda_kernels_cast_f32_i32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f32 -> i64`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f32` in,
/// `i64` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f32_i64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f32_i64`; host-side checks only.
pub fn baracuda_kernels_cast_f32_i64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f32 -> u8`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f32` in,
/// `u8` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f32_u8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f32_u8`; host-side checks only.
pub fn baracuda_kernels_cast_f32_u8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f32 -> i8`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f32` in,
/// `i8` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f32_i8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f32_i8`; host-side checks only.
pub fn baracuda_kernels_cast_f32_i8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// f64 -> *
/// Cast `f64 -> f32`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f64` in,
/// `f32` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f64_f32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f64_f32`; host-side checks only.
pub fn baracuda_kernels_cast_f64_f32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f64 -> f64`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f64` in,
/// `f64` out); `stream` must be a live CUDA stream in the current context.
/// `y` may alias `x`: identity cast, same element byte width.
pub fn baracuda_kernels_cast_f64_f64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f64_f64`; host-side checks only.
pub fn baracuda_kernels_cast_f64_f64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f64 -> f16`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f64` in,
/// `f16` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f64_f16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f64_f16`; host-side checks only.
pub fn baracuda_kernels_cast_f64_f16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f64 -> bf16`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f64` in,
/// `bf16` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f64_bf16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f64_bf16`; host-side checks only.
pub fn baracuda_kernels_cast_f64_bf16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f64 -> i32`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f64` in,
/// `i32` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f64_i32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f64_i32`; host-side checks only.
pub fn baracuda_kernels_cast_f64_i32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f64 -> i64`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f64` in,
/// `i64` out); `stream` must be a live CUDA stream in the current context.
/// `y` may alias `x`: both element types have the same byte width.
pub fn baracuda_kernels_cast_f64_i64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f64_i64`; host-side checks only.
pub fn baracuda_kernels_cast_f64_i64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f64 -> u8`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f64` in,
/// `u8` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f64_u8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f64_u8`; host-side checks only.
pub fn baracuda_kernels_cast_f64_u8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f64 -> i8`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f64` in,
/// `i8` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f64_i8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f64_i8`; host-side checks only.
pub fn baracuda_kernels_cast_f64_i8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// f16 -> *
/// Cast `f16 -> f32`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f16` in,
/// `f32` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f16_f32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f16_f32`; host-side checks only.
pub fn baracuda_kernels_cast_f16_f32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f16 -> f64`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f16` in,
/// `f64` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f16_f64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f16_f64`; host-side checks only.
pub fn baracuda_kernels_cast_f16_f64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f16 -> f16`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f16` in,
/// `f16` out); `stream` must be a live CUDA stream in the current context.
/// `y` may alias `x`: identity cast, same element byte width.
pub fn baracuda_kernels_cast_f16_f16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f16_f16`; host-side checks only.
pub fn baracuda_kernels_cast_f16_f16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f16 -> bf16`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f16` in,
/// `bf16` out); `stream` must be a live CUDA stream in the current context.
/// `y` may alias `x`: both element types have the same byte width.
pub fn baracuda_kernels_cast_f16_bf16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f16_bf16`; host-side checks only.
pub fn baracuda_kernels_cast_f16_bf16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f16 -> i32`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f16` in,
/// `i32` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f16_i32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f16_i32`; host-side checks only.
pub fn baracuda_kernels_cast_f16_i32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f16 -> i64`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f16` in,
/// `i64` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f16_i64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f16_i64`; host-side checks only.
pub fn baracuda_kernels_cast_f16_i64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f16 -> u8`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f16` in,
/// `u8` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f16_u8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f16_u8`; host-side checks only.
pub fn baracuda_kernels_cast_f16_u8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f16 -> i8`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`f16` in,
/// `i8` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_f16_i8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f16_i8`; host-side checks only.
pub fn baracuda_kernels_cast_f16_i8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// bf16 -> *
/// Cast `bf16 -> f32`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`bf16` in,
/// `f32` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_bf16_f32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_bf16_f32`; host-side checks only.
pub fn baracuda_kernels_cast_bf16_f32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `bf16 -> f64`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`bf16` in,
/// `f64` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_bf16_f64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_bf16_f64`; host-side checks only.
pub fn baracuda_kernels_cast_bf16_f64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `bf16 -> f16`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`bf16` in,
/// `f16` out); `stream` must be a live CUDA stream in the current context.
/// `y` may alias `x`: both element types have the same byte width.
pub fn baracuda_kernels_cast_bf16_f16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_bf16_f16`; host-side checks only.
pub fn baracuda_kernels_cast_bf16_f16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `bf16 -> bf16`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`bf16` in,
/// `bf16` out); `stream` must be a live CUDA stream in the current context.
/// `y` may alias `x`: identity cast, same element byte width.
pub fn baracuda_kernels_cast_bf16_bf16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_bf16_bf16`; host-side checks only.
pub fn baracuda_kernels_cast_bf16_bf16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `bf16 -> i32`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`bf16` in,
/// `i32` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_bf16_i32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_bf16_i32`; host-side checks only.
pub fn baracuda_kernels_cast_bf16_i32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `bf16 -> i64`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`bf16` in,
/// `i64` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_bf16_i64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_bf16_i64`; host-side checks only.
pub fn baracuda_kernels_cast_bf16_i64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `bf16 -> u8`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`bf16` in,
/// `u8` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_bf16_u8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_bf16_u8`; host-side checks only.
pub fn baracuda_kernels_cast_bf16_u8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `bf16 -> i8`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`bf16` in,
/// `i8` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_bf16_i8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_bf16_i8`; host-side checks only.
pub fn baracuda_kernels_cast_bf16_i8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// i32 -> *
/// Cast `i32 -> f32`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`i32` in,
/// `f32` out); `stream` must be a live CUDA stream in the current context.
/// `y` may alias `x`: both element types have the same byte width.
pub fn baracuda_kernels_cast_i32_f32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i32_f32`; host-side checks only.
pub fn baracuda_kernels_cast_i32_f32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i32 -> f64`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`i32` in,
/// `f64` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_i32_f64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i32_f64`; host-side checks only.
pub fn baracuda_kernels_cast_i32_f64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i32 -> f16`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`i32` in,
/// `f16` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_i32_f16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i32_f16`; host-side checks only.
pub fn baracuda_kernels_cast_i32_f16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i32 -> bf16`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`i32` in,
/// `bf16` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_i32_bf16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i32_bf16`; host-side checks only.
pub fn baracuda_kernels_cast_i32_bf16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i32 -> i32`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`i32` in,
/// `i32` out); `stream` must be a live CUDA stream in the current context.
/// `y` may alias `x`: identity cast, same element byte width.
pub fn baracuda_kernels_cast_i32_i32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i32_i32`; host-side checks only.
pub fn baracuda_kernels_cast_i32_i32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i32 -> i64`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`i32` in,
/// `i64` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_i32_i64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i32_i64`; host-side checks only.
pub fn baracuda_kernels_cast_i32_i64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i32 -> u8`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`i32` in,
/// `u8` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_i32_u8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i32_u8`; host-side checks only.
pub fn baracuda_kernels_cast_i32_u8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i32 -> i8`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`i32` in,
/// `i8` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_i32_i8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i32_i8`; host-side checks only.
pub fn baracuda_kernels_cast_i32_i8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// i64 -> *
/// Cast `i64 -> f32`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`i64` in,
/// `f32` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_i64_f32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i64_f32`; host-side checks only.
pub fn baracuda_kernels_cast_i64_f32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i64 -> f64`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`i64` in,
/// `f64` out); `stream` must be a live CUDA stream in the current context.
/// `y` may alias `x`: both element types have the same byte width.
pub fn baracuda_kernels_cast_i64_f64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i64_f64`; host-side checks only.
pub fn baracuda_kernels_cast_i64_f64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i64 -> f16`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`i64` in,
/// `f16` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_i64_f16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i64_f16`; host-side checks only.
pub fn baracuda_kernels_cast_i64_f16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i64 -> bf16`.
///
/// Converts `numel` contiguous elements from `x` into `y` and returns a
/// crate-level status code (`0` = success).
///
/// # Safety
/// `x` / `y` must each address at least `numel` device elements (`i64` in,
/// `bf16` out); `stream` must be a live CUDA stream in the current context.
/// `y` must NOT alias `x`: the element byte widths differ, so in-place use
/// is a data race.
pub fn baracuda_kernels_cast_i64_bf16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i64_bf16`; host-side checks only.
pub fn baracuda_kernels_cast_i64_bf16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i64 -> i32`.
pub fn baracuda_kernels_cast_i64_i32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i64_i32`.
pub fn baracuda_kernels_cast_i64_i32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i64 -> i64`.
pub fn baracuda_kernels_cast_i64_i64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i64_i64`.
pub fn baracuda_kernels_cast_i64_i64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i64 -> u8`.
pub fn baracuda_kernels_cast_i64_u8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i64_u8`.
pub fn baracuda_kernels_cast_i64_u8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i64 -> i8`.
pub fn baracuda_kernels_cast_i64_i8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i64_i8`.
pub fn baracuda_kernels_cast_i64_i8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// u8 -> *
//
// Shared shape of every pair in this group: `*_run` takes
// `(numel, x, y, workspace, workspace_bytes, stream)` and
// `*_can_implement` takes only `(numel, x, y)`. Both return the
// crate-level `i32` status code (`0` = success).
/// Cast `u8 -> f32`.
pub fn baracuda_kernels_cast_u8_f32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u8_f32`.
pub fn baracuda_kernels_cast_u8_f32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u8 -> f64`.
pub fn baracuda_kernels_cast_u8_f64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u8_f64`.
pub fn baracuda_kernels_cast_u8_f64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u8 -> f16`.
pub fn baracuda_kernels_cast_u8_f16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u8_f16`.
pub fn baracuda_kernels_cast_u8_f16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u8 -> bf16`.
pub fn baracuda_kernels_cast_u8_bf16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u8_bf16`.
pub fn baracuda_kernels_cast_u8_bf16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u8 -> i32`.
pub fn baracuda_kernels_cast_u8_i32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u8_i32`.
pub fn baracuda_kernels_cast_u8_i32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u8 -> i64`.
pub fn baracuda_kernels_cast_u8_i64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u8_i64`.
pub fn baracuda_kernels_cast_u8_i64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u8 -> u8`.
pub fn baracuda_kernels_cast_u8_u8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u8_u8`.
pub fn baracuda_kernels_cast_u8_u8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u8 -> i8`.
pub fn baracuda_kernels_cast_u8_i8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u8_i8`.
pub fn baracuda_kernels_cast_u8_i8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// i8 -> *
//
// Shared shape of every pair in this group: `*_run` takes
// `(numel, x, y, workspace, workspace_bytes, stream)` and
// `*_can_implement` takes only `(numel, x, y)`. Both return the
// crate-level `i32` status code (`0` = success).
/// Cast `i8 -> f32`.
pub fn baracuda_kernels_cast_i8_f32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i8_f32`.
pub fn baracuda_kernels_cast_i8_f32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i8 -> f64`.
pub fn baracuda_kernels_cast_i8_f64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i8_f64`.
pub fn baracuda_kernels_cast_i8_f64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i8 -> f16`.
pub fn baracuda_kernels_cast_i8_f16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i8_f16`.
pub fn baracuda_kernels_cast_i8_f16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i8 -> bf16`.
pub fn baracuda_kernels_cast_i8_bf16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i8_bf16`.
pub fn baracuda_kernels_cast_i8_bf16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i8 -> i32`.
pub fn baracuda_kernels_cast_i8_i32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i8_i32`.
pub fn baracuda_kernels_cast_i8_i32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i8 -> i64`.
pub fn baracuda_kernels_cast_i8_i64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i8_i64`.
pub fn baracuda_kernels_cast_i8_i64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i8 -> u8`.
pub fn baracuda_kernels_cast_i8_u8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i8_u8`.
pub fn baracuda_kernels_cast_i8_u8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i8 -> i8`.
pub fn baracuda_kernels_cast_i8_i8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i8_i8`.
pub fn baracuda_kernels_cast_i8_i8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// ============================================================================
// Phase 31 — u32 + i16 cast matrix extensions (Fuel Phase 6c.2
// storage.rs unblock). Brings the regular-cast dtype matrix from
// 8×8 (f32/f64/f16/bf16/i32/i64/u8/i8) to 10×10 by adding u32 +
// i16 as both source and destination. 36 net new cells = 16 "to
// {u32, i16}" (from the 8 existing dtypes) + 20 "from {u32, i16}"
// (to all 10 dtypes). The 2 u32↔i16 cross pairs and the 2 identity
// pairs land in the "from" blocks below, so nothing is counted twice.
//
// Each pair uses the same `(numel, x, y, ws, ws_b, stream) -> i32`
// ABI as the rest of the cast family. Truncation semantics match
// C++ `static_cast<TOut>(x)` — wraparound for narrowing integer
// casts, truncation toward zero for FP→integer.
// ============================================================================
// ----- f32/f64/f16/bf16/i32/i64/u8/i8 -> u32 -----------------------------
/// Cast `f32 -> u32`. Negative inputs are undefined per C++ rules
/// (typical NVCC behaviour: saturates toward 0). Phase 31.
pub fn baracuda_kernels_cast_f32_u32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f32_u32`. Phase 31.
pub fn baracuda_kernels_cast_f32_u32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f64 -> u32`. Phase 31.
pub fn baracuda_kernels_cast_f64_u32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f64_u32`. Phase 31.
pub fn baracuda_kernels_cast_f64_u32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f16 -> u32`. Phase 31.
pub fn baracuda_kernels_cast_f16_u32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f16_u32`. Phase 31.
pub fn baracuda_kernels_cast_f16_u32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `bf16 -> u32`. Phase 31.
pub fn baracuda_kernels_cast_bf16_u32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_bf16_u32`. Phase 31.
pub fn baracuda_kernels_cast_bf16_u32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i32 -> u32`. Bitwise reinterpret for the common case
/// (`x >= 0`); two's-complement wraparound otherwise. Phase 31.
pub fn baracuda_kernels_cast_i32_u32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i32_u32`. Phase 31.
pub fn baracuda_kernels_cast_i32_u32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i64 -> u32`. Truncates the top 32 bits. Phase 31.
pub fn baracuda_kernels_cast_i64_u32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i64_u32`. Phase 31.
pub fn baracuda_kernels_cast_i64_u32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u8 -> u32`. Zero-extends. Phase 31.
pub fn baracuda_kernels_cast_u8_u32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u8_u32`. Phase 31.
pub fn baracuda_kernels_cast_u8_u32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i8 -> u32`. Sign-extends then reinterprets. Phase 31.
pub fn baracuda_kernels_cast_i8_u32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i8_u32`. Phase 31.
pub fn baracuda_kernels_cast_i8_u32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// ----- f32/f64/f16/bf16/i32/i64/u8/i8 -> i16 -----------------------------
/// Cast `f32 -> i16`. Phase 31.
pub fn baracuda_kernels_cast_f32_i16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f32_i16`. Phase 31.
pub fn baracuda_kernels_cast_f32_i16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f64 -> i16`. Phase 31.
pub fn baracuda_kernels_cast_f64_i16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f64_i16`. Phase 31.
pub fn baracuda_kernels_cast_f64_i16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f16 -> i16`. Phase 31.
pub fn baracuda_kernels_cast_f16_i16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f16_i16`. Phase 31.
pub fn baracuda_kernels_cast_f16_i16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `bf16 -> i16`. Phase 31.
pub fn baracuda_kernels_cast_bf16_i16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_bf16_i16`. Phase 31.
pub fn baracuda_kernels_cast_bf16_i16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i32 -> i16`. Truncates to low 16 bits. Phase 31.
pub fn baracuda_kernels_cast_i32_i16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i32_i16`. Phase 31.
pub fn baracuda_kernels_cast_i32_i16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i64 -> i16`. Truncates to low 16 bits. Phase 31.
pub fn baracuda_kernels_cast_i64_i16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i64_i16`. Phase 31.
pub fn baracuda_kernels_cast_i64_i16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u8 -> i16`. Zero-extends. Phase 31.
pub fn baracuda_kernels_cast_u8_i16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u8_i16`. Phase 31.
pub fn baracuda_kernels_cast_u8_i16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i8 -> i16`. Sign-extends. Phase 31.
pub fn baracuda_kernels_cast_i8_i16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i8_i16`. Phase 31.
pub fn baracuda_kernels_cast_i8_i16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// ----- u32 -> * ----------------------------------------------------------
/// Cast `u32 -> f32`. Phase 31.
pub fn baracuda_kernels_cast_u32_f32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u32_f32`. Phase 31.
pub fn baracuda_kernels_cast_u32_f32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u32 -> f64`. Phase 31.
pub fn baracuda_kernels_cast_u32_f64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u32_f64`. Phase 31.
pub fn baracuda_kernels_cast_u32_f64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u32 -> f16`. Phase 31.
pub fn baracuda_kernels_cast_u32_f16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u32_f16`. Phase 31.
pub fn baracuda_kernels_cast_u32_f16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u32 -> bf16`. Phase 31.
pub fn baracuda_kernels_cast_u32_bf16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u32_bf16`. Phase 31.
pub fn baracuda_kernels_cast_u32_bf16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u32 -> i32`. Bitwise reinterpret. Phase 31.
pub fn baracuda_kernels_cast_u32_i32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u32_i32`. Phase 31.
pub fn baracuda_kernels_cast_u32_i32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u32 -> i64`. Zero-extends. Phase 31.
pub fn baracuda_kernels_cast_u32_i64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u32_i64`. Phase 31.
pub fn baracuda_kernels_cast_u32_i64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u32 -> u8`. Truncates to low byte. Phase 31.
pub fn baracuda_kernels_cast_u32_u8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u32_u8`. Phase 31.
pub fn baracuda_kernels_cast_u32_u8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u32 -> i8`. Truncates to low byte then reinterprets. Phase 31.
pub fn baracuda_kernels_cast_u32_i8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u32_i8`. Phase 31.
pub fn baracuda_kernels_cast_u32_i8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u32 -> u32` (identity). Phase 31.
pub fn baracuda_kernels_cast_u32_u32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u32_u32`. Phase 31.
pub fn baracuda_kernels_cast_u32_u32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `u32 -> i16`. Truncates to low 16 bits then reinterprets. Phase 31.
pub fn baracuda_kernels_cast_u32_i16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_u32_i16`. Phase 31.
pub fn baracuda_kernels_cast_u32_i16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// ----- i16 -> * ----------------------------------------------------------
/// Cast `i16 -> f32`. Phase 31.
pub fn baracuda_kernels_cast_i16_f32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i16_f32`. Phase 31.
pub fn baracuda_kernels_cast_i16_f32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i16 -> f64`. Phase 31.
pub fn baracuda_kernels_cast_i16_f64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i16_f64`. Phase 31.
pub fn baracuda_kernels_cast_i16_f64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i16 -> f16`. Phase 31.
pub fn baracuda_kernels_cast_i16_f16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i16_f16`. Phase 31.
pub fn baracuda_kernels_cast_i16_f16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i16 -> bf16`. Phase 31.
pub fn baracuda_kernels_cast_i16_bf16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i16_bf16`. Phase 31.
pub fn baracuda_kernels_cast_i16_bf16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i16 -> i32`. Sign-extends. Phase 31.
pub fn baracuda_kernels_cast_i16_i32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i16_i32`. Phase 31.
pub fn baracuda_kernels_cast_i16_i32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i16 -> i64`. Sign-extends. Phase 31.
pub fn baracuda_kernels_cast_i16_i64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i16_i64`. Phase 31.
pub fn baracuda_kernels_cast_i16_i64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i16 -> u8`. Truncates to low byte then reinterprets. Phase 31.
pub fn baracuda_kernels_cast_i16_u8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i16_u8`. Phase 31.
pub fn baracuda_kernels_cast_i16_u8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i16 -> i8`. Truncates to low byte. Phase 31.
pub fn baracuda_kernels_cast_i16_i8_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i16_i8`. Phase 31.
pub fn baracuda_kernels_cast_i16_i8_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i16 -> u32`. Sign-extends to i32 then reinterprets. Phase 31.
pub fn baracuda_kernels_cast_i16_u32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i16_u32`. Phase 31.
pub fn baracuda_kernels_cast_i16_u32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i16 -> i16` (identity). Phase 31.
pub fn baracuda_kernels_cast_i16_i16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i16_i16`. Phase 31.
pub fn baracuda_kernels_cast_i16_i16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// ============================================================================
// Phase 13.3 — sub-byte cast paths (Bool / Fp8 / S4 / U4).
//
// These cover dtypes that the classic `baracuda_cast.cuh` doesn't
// wire because their conversion semantics differ from a plain
// `static_cast<TOut>(x)`:
//
// * Bool ↔ T: 0/non-zero truthiness normalization.
// * Fp8 ↔ {f32, f16, bf16}: routes through f32 via NVIDIA's
// `__nv_cvt_*_fp8` intrinsics with SATFINITE semantics.
// * S4 / U4 ↔ {i32, i64, f32}: packed-pair nibble storage. UNPACK
// (S4/U4 → wide) sign- or zero-extends; PACK (wide → S4/U4)
// saturates to [-8, 7] or [0, 15] before nibble-masking.
// `numel` is the element count and must be even; the packed
// buffer holds `numel / 2` bytes.
// ============================================================================
// ----- Bool -> { i32, i64, f32, f16, bf16 } ----------------------------
// Every `*_run` below takes `(numel, x, y, workspace, workspace_bytes,
// stream)` and every `*_can_implement` takes only `(numel, x, y)`; both
// return the crate-level `i32` status code (`0` = success).
/// Cast `Bool -> i32`. `x != 0 → 1`.
pub fn baracuda_kernels_cast_bool_i32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_bool_i32`.
pub fn baracuda_kernels_cast_bool_i32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `Bool -> i64`.
pub fn baracuda_kernels_cast_bool_i64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_bool_i64`.
pub fn baracuda_kernels_cast_bool_i64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `Bool -> f32`.
pub fn baracuda_kernels_cast_bool_f32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_bool_f32`.
pub fn baracuda_kernels_cast_bool_f32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `Bool -> f16`.
pub fn baracuda_kernels_cast_bool_f16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_bool_f16`.
pub fn baracuda_kernels_cast_bool_f16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `Bool -> bf16`.
pub fn baracuda_kernels_cast_bool_bf16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_bool_bf16`.
pub fn baracuda_kernels_cast_bool_bf16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// ----- { i32, i64, f32, f16, bf16 } -> Bool ----------------------------
// Every `*_run` below takes `(numel, x, y, workspace, workspace_bytes,
// stream)` and every `*_can_implement` takes only `(numel, x, y)`; both
// return the crate-level `i32` status code (`0` = success).
/// Cast `i32 -> Bool`. `x != 0 → 1`.
pub fn baracuda_kernels_cast_i32_bool_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i32_bool`.
pub fn baracuda_kernels_cast_i32_bool_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i64 -> Bool`.
pub fn baracuda_kernels_cast_i64_bool_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_i64_bool`.
pub fn baracuda_kernels_cast_i64_bool_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f32 -> Bool`.
pub fn baracuda_kernels_cast_f32_bool_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f32_bool`.
pub fn baracuda_kernels_cast_f32_bool_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f16 -> Bool`.
pub fn baracuda_kernels_cast_f16_bool_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f16_bool`.
pub fn baracuda_kernels_cast_f16_bool_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `bf16 -> Bool`.
pub fn baracuda_kernels_cast_bf16_bool_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_bf16_bool`.
pub fn baracuda_kernels_cast_bf16_bool_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// ----- Fp8E4M3 ↔ { f32, f16, bf16 } ------------------------------------
/// Cast `Fp8E4M3 -> f32`.
pub fn baracuda_kernels_cast_fp8e4m3_f32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_fp8e4m3_f32`.
pub fn baracuda_kernels_cast_fp8e4m3_f32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `Fp8E4M3 -> f16`.
pub fn baracuda_kernels_cast_fp8e4m3_f16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_fp8e4m3_f16`.
pub fn baracuda_kernels_cast_fp8e4m3_f16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `Fp8E4M3 -> bf16`.
pub fn baracuda_kernels_cast_fp8e4m3_bf16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_fp8e4m3_bf16`.
pub fn baracuda_kernels_cast_fp8e4m3_bf16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f32 -> Fp8E4M3` (saturates to ±448).
pub fn baracuda_kernels_cast_f32_fp8e4m3_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f32_fp8e4m3`.
pub fn baracuda_kernels_cast_f32_fp8e4m3_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f16 -> Fp8E4M3`.
pub fn baracuda_kernels_cast_f16_fp8e4m3_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_f16_fp8e4m3`.
pub fn baracuda_kernels_cast_f16_fp8e4m3_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `bf16 -> Fp8E4M3`.
pub fn baracuda_kernels_cast_bf16_fp8e4m3_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for `cast_bf16_fp8e4m3`.
pub fn baracuda_kernels_cast_bf16_fp8e4m3_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// ----- Fp8E5M2 ↔ { f32, f16, bf16 } ------------------------------------
/// Cast `Fp8E5M2 -> f32`.
pub fn baracuda_kernels_cast_fp8e5m2_f32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// `baracuda_kernels_cast_fp8e5m2_f32_can_implement` (baracuda kernels cast fp8e5m2 f32 can implement).
pub fn baracuda_kernels_cast_fp8e5m2_f32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `Fp8E5M2 -> f16`.
pub fn baracuda_kernels_cast_fp8e5m2_f16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// `baracuda_kernels_cast_fp8e5m2_f16_can_implement` (baracuda kernels cast fp8e5m2 f16 can implement).
pub fn baracuda_kernels_cast_fp8e5m2_f16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `Fp8E5M2 -> bf16`.
pub fn baracuda_kernels_cast_fp8e5m2_bf16_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// `baracuda_kernels_cast_fp8e5m2_bf16_can_implement` (baracuda kernels cast fp8e5m2 bf16 can implement).
pub fn baracuda_kernels_cast_fp8e5m2_bf16_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f32 -> Fp8E5M2` (saturates to ±57344).
pub fn baracuda_kernels_cast_f32_fp8e5m2_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// `baracuda_kernels_cast_f32_fp8e5m2_can_implement` (baracuda kernels cast f32 fp8e5m2 can implement).
pub fn baracuda_kernels_cast_f32_fp8e5m2_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f16 -> Fp8E5M2`.
pub fn baracuda_kernels_cast_f16_fp8e5m2_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for the `f16 -> Fp8E5M2` cast.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_cast_f16_fp8e5m2_run`] and returns an `i32`
/// status code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_cast_f16_fp8e5m2_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `bf16 -> Fp8E5M2`.
///
/// `x` is the read-only source and `y` the destination. Returns an
/// `i32` status code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_cast_bf16_fp8e5m2_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for the `bf16 -> Fp8E5M2` cast.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_cast_bf16_fp8e5m2_run`] and returns an `i32`
/// status code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_cast_bf16_fp8e5m2_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// ----- S4 ↔ { i32, i64, f32 } — packed nibble ------------------------
// UNPACK (S4 → wide): `numel` must be even; sign-extends each nibble.
/// Cast `S4 -> i32` (unpack: sign-extend nibble to int32).
///
/// `x` is the packed-nibble source and `y` the destination; `numel`
/// must be even. Returns an `i32` status code (`0` = success; see the
/// crate-level docs).
pub fn baracuda_kernels_cast_s4_i32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for the `S4 -> i32` unpack cast.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_cast_s4_i32_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_cast_s4_i32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `S4 -> i64` (unpack: sign-extends each nibble).
///
/// `x` is the packed-nibble source and `y` the destination; `numel`
/// must be even. Returns an `i32` status code (`0` = success; see the
/// crate-level docs).
pub fn baracuda_kernels_cast_s4_i64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for the `S4 -> i64` unpack cast.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_cast_s4_i64_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_cast_s4_i64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `S4 -> f32` (unpack: sign-extends each nibble).
///
/// `x` is the packed-nibble source and `y` the destination; `numel`
/// must be even. Returns an `i32` status code (`0` = success; see the
/// crate-level docs).
pub fn baracuda_kernels_cast_s4_f32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for the `S4 -> f32` unpack cast.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_cast_s4_f32_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_cast_s4_f32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// PACK (wide → S4): `numel` must be even; saturates inputs to [-8, +7].
/// Cast `i32 -> S4` (pack: saturate to [-8, +7] then nibble-mask).
///
/// `x` is the wide source and `y` the packed-nibble destination;
/// `numel` must be even. Returns an `i32` status code (`0` = success;
/// see the crate-level docs).
pub fn baracuda_kernels_cast_i32_s4_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for the `i32 -> S4` pack cast.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_cast_i32_s4_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_cast_i32_s4_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i64 -> S4` (pack: saturates inputs to [-8, +7]).
///
/// `x` is the wide source and `y` the packed-nibble destination;
/// `numel` must be even. Returns an `i32` status code (`0` = success;
/// see the crate-level docs).
pub fn baracuda_kernels_cast_i64_s4_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for the `i64 -> S4` pack cast.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_cast_i64_s4_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_cast_i64_s4_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f32 -> S4` (round-to-nearest then saturate).
///
/// `x` is the wide source and `y` the packed-nibble destination;
/// `numel` must be even and inputs saturate to [-8, +7]. Returns an
/// `i32` status code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_cast_f32_s4_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for the `f32 -> S4` pack cast.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_cast_f32_s4_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_cast_f32_s4_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// ----- U4 ↔ { i32, i64, f32 } — packed nibble ------------------------
/// Cast `U4 -> i32` (unpack: zero-extend nibble to int32).
///
/// `x` is the packed-nibble source and `y` the destination. Returns an
/// `i32` status code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_cast_u4_i32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for the `U4 -> i32` unpack cast.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_cast_u4_i32_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_cast_u4_i32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `U4 -> i64`.
///
/// `x` is the packed-nibble source and `y` the destination. Presumably
/// zero-extends each nibble like the `U4 -> i32` variant (TODO confirm).
/// Returns an `i32` status code (`0` = success; see the crate-level
/// docs).
pub fn baracuda_kernels_cast_u4_i64_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for the `U4 -> i64` unpack cast.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_cast_u4_i64_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_cast_u4_i64_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `U4 -> f32`.
///
/// `x` is the packed-nibble source and `y` the destination. Presumably
/// treats each nibble as unsigned like the `U4 -> i32` variant (TODO
/// confirm). Returns an `i32` status code (`0` = success; see the
/// crate-level docs).
pub fn baracuda_kernels_cast_u4_f32_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for the `U4 -> f32` unpack cast.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_cast_u4_f32_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_cast_u4_f32_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i32 -> U4` (pack: saturate to [0, 15] then nibble-mask).
///
/// `x` is the wide source and `y` the packed-nibble destination.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// docs).
pub fn baracuda_kernels_cast_i32_u4_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for the `i32 -> U4` pack cast.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_cast_i32_u4_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_cast_i32_u4_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `i64 -> U4`.
///
/// `x` is the wide source and `y` the packed-nibble destination.
/// Presumably saturates to [0, 15] like the `i32 -> U4` variant (TODO
/// confirm). Returns an `i32` status code (`0` = success; see the
/// crate-level docs).
pub fn baracuda_kernels_cast_i64_u4_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for the `i64 -> U4` pack cast.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_cast_i64_u4_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_cast_i64_u4_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
/// Cast `f32 -> U4` (round-to-nearest then saturate).
///
/// `x` is the wide source and `y` the packed-nibble destination.
/// Returns an `i32` status code (`0` = success; see the crate-level
/// docs).
pub fn baracuda_kernels_cast_f32_u4_run(numel: i64, x: *const c_void, y: *mut c_void, workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void) -> i32;
/// Implementability check for the `f32 -> U4` pack cast.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_cast_f32_u4_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_cast_f32_u4_can_implement(numel: i64, x: *const c_void, y: *const c_void) -> i32;
// ----- Fill ------------------------------------------------------------
/// Fill `y` with `value`, f32 dtype. This is the fill trailblazer —
/// every `fill_<dt>_run` (and `_strided_run`) variant follows the
/// same write-only contract.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// docs).
///
/// # Safety
/// `y` must point to at least `numel * sizeof::<f32>()` bytes of
/// device memory. `stream` must be a live CUDA stream.
///
/// **Aliasing (Phase 64)**: fill is write-only — there are no
/// input device pointers to alias against. The kernel body is
/// trivially `y[i] = value` (write-only per thread, no reads).
/// Trivially in-place; no aliasing concerns.
pub fn baracuda_kernels_fill_f32_run(
numel: i64,
y: *mut c_void,
value: f32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `fill_f32`. Host-side only.
///
/// Takes the `numel` and `y` of [`baracuda_kernels_fill_f32_run`] and
/// returns an `i32` status code (`0` = success; see the crate-level
/// docs).
pub fn baracuda_kernels_fill_f32_can_implement(numel: i64, y: *const c_void) -> i32;
/// Fill `y` with `value`, f64 dtype.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// docs).
///
/// # Safety
/// Same contract as the f32 variant; storage is f64.
pub fn baracuda_kernels_fill_f64_run(
numel: i64,
y: *mut c_void,
value: f64,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `fill_f64`. Host-side only.
///
/// Takes the `numel` and `y` of [`baracuda_kernels_fill_f64_run`] and
/// returns an `i32` status code (`0` = success; see the crate-level
/// docs).
pub fn baracuda_kernels_fill_f64_can_implement(numel: i64, y: *const c_void) -> i32;
/// Fill `y` with `value`, i32 dtype.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// docs).
///
/// # Safety
/// Same contract as the f32 variant; storage is i32.
pub fn baracuda_kernels_fill_i32_run(
numel: i64,
y: *mut c_void,
value: i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `fill_i32`. Host-side only.
///
/// Takes the `numel` and `y` of [`baracuda_kernels_fill_i32_run`] and
/// returns an `i32` status code (`0` = success; see the crate-level
/// docs).
pub fn baracuda_kernels_fill_i32_can_implement(numel: i64, y: *const c_void) -> i32;
/// Fill `y` with `value`, i64 dtype.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// docs).
///
/// # Safety
/// Same contract as the f32 variant; storage is i64.
pub fn baracuda_kernels_fill_i64_run(
numel: i64,
y: *mut c_void,
value: i64,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `fill_i64`. Host-side only.
///
/// Takes the `numel` and `y` of [`baracuda_kernels_fill_i64_run`] and
/// returns an `i32` status code (`0` = success; see the crate-level
/// docs).
pub fn baracuda_kernels_fill_i64_can_implement(numel: i64, y: *const c_void) -> i32;
/// Fill `y` with `value`, u8 dtype.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// docs).
///
/// # Safety
/// Same contract as the f32 variant; storage is u8.
pub fn baracuda_kernels_fill_u8_run(
numel: i64,
y: *mut c_void,
value: u8,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `fill_u8`. Host-side only.
///
/// Takes the `numel` and `y` of [`baracuda_kernels_fill_u8_run`] and
/// returns an `i32` status code (`0` = success; see the crate-level
/// docs).
pub fn baracuda_kernels_fill_u8_can_implement(numel: i64, y: *const c_void) -> i32;
/// Fill `y` with `value`, i8 dtype.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// docs).
///
/// # Safety
/// Same contract as the f32 variant; storage is i8.
pub fn baracuda_kernels_fill_i8_run(
numel: i64,
y: *mut c_void,
value: i8,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `fill_i8`. Host-side only.
///
/// Takes the `numel` and `y` of [`baracuda_kernels_fill_i8_run`] and
/// returns an `i32` status code (`0` = success; see the crate-level
/// docs).
pub fn baracuda_kernels_fill_i8_can_implement(numel: i64, y: *const c_void) -> i32;
/// Fill `y` with `value`, f16 dtype. `value_bits` is the raw
/// 16-bit pattern of an `f16` value (transport convention shared
/// with the Pad-constant family).
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// docs).
///
/// # Safety
/// `y` must point to at least `numel * 2` bytes of device memory.
pub fn baracuda_kernels_fill_f16_run(
numel: i64,
y: *mut c_void,
value_bits: u16,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `fill_f16`. Host-side only.
///
/// Takes the `numel` and `y` of [`baracuda_kernels_fill_f16_run`] and
/// returns an `i32` status code (`0` = success; see the crate-level
/// docs).
pub fn baracuda_kernels_fill_f16_can_implement(numel: i64, y: *const c_void) -> i32;
/// Fill `y` with `value`, bf16 dtype. `value_bits` is the raw
/// 16-bit pattern of a `bf16` value.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// docs).
///
/// # Safety
/// `y` must point to at least `numel * 2` bytes of device memory.
pub fn baracuda_kernels_fill_bf16_run(
numel: i64,
y: *mut c_void,
value_bits: u16,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `fill_bf16`. Host-side only.
///
/// Takes the `numel` and `y` of [`baracuda_kernels_fill_bf16_run`] and
/// returns an `i32` status code (`0` = success; see the crate-level
/// docs).
pub fn baracuda_kernels_fill_bf16_can_implement(numel: i64, y: *const c_void) -> i32;
// ----- Phase 36 (Fuel ask Gap 4) — additional dtypes -----
//
// Contig fill for u32, i16, and FP8 E4M3 (raw u8 storage). These follow
// the same `T value` ABI as the existing contig fill family; FP8
// is transported as raw `u8` since the storage type is byte-
// identical to `uint8_t`.
/// Fill `y` with `value`, u32 dtype.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// docs).
///
/// # Safety
/// Same write-only contract as [`baracuda_kernels_fill_f32_run`];
/// storage is u32.
pub fn baracuda_kernels_fill_u32_run(
numel: i64,
y: *mut c_void,
value: u32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `fill_u32`.
///
/// Takes the `numel` and `y` of [`baracuda_kernels_fill_u32_run`] and
/// returns an `i32` status code (`0` = success; see the crate-level
/// docs).
pub fn baracuda_kernels_fill_u32_can_implement(numel: i64, y: *const c_void) -> i32;
/// Fill `y` with `value`, i16 dtype.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// docs).
///
/// # Safety
/// Same write-only contract as [`baracuda_kernels_fill_f32_run`];
/// storage is i16.
pub fn baracuda_kernels_fill_i16_run(
numel: i64,
y: *mut c_void,
value: i16,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `fill_i16`.
///
/// Takes the `numel` and `y` of [`baracuda_kernels_fill_i16_run`] and
/// returns an `i32` status code (`0` = success; see the crate-level
/// docs).
pub fn baracuda_kernels_fill_i16_can_implement(numel: i64, y: *const c_void) -> i32;
/// Fill `y` with `value`, FP8 E4M3 dtype. `value` is the raw 8-bit
/// E4M3 encoding (storage is byte-identical to `u8`); callers
/// compute the encoding via the cast family or `__nv_cvt_float_to_fp8`.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// docs).
///
/// # Safety
/// Same write-only contract as [`baracuda_kernels_fill_f32_run`];
/// storage is one byte per element.
pub fn baracuda_kernels_fill_fp8e4m3_run(
numel: i64,
y: *mut c_void,
value: u8,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `fill_fp8e4m3`.
///
/// Takes the `numel` and `y` of [`baracuda_kernels_fill_fp8e4m3_run`]
/// and returns an `i32` status code (`0` = success; see the
/// crate-level docs).
pub fn baracuda_kernels_fill_fp8e4m3_can_implement(
numel: i64, y: *const c_void,
) -> i32;
// ----- Phase 36 (Fuel ask Gap 4) — strided fill -----
//
// Strided variants for all 11 dtypes (existing 8 + 3 new). The
// logical output is `y[lin(coord)] = value` where `coord` iterates
// row-major over `shape[0..rank]` and `lin(coord) = Σ coord[axis]
// * stride_y[axis]`. `numel` must equal `Π shape[d]`. Rank up to
// `MAX_RANK = 8` (matches affine.cuh's `MAX_RANK`). `shape` and
// `stride_y` are HOST-side arrays (copied into a kernel param
// block).
//
// Strides are signed `i64` (negative-stride / broadcast-stride
// supported). f16 / bf16 transport `value` as a raw `u16` bit
// pattern; FP8 E4M3 transports raw `u8`; other dtypes pass `value`
// by their natural type.
/// Strided fill, f32 dtype: writes `value` to `y[lin(coord)]` for every
/// `coord` iterated row-major over `shape[0..rank]`, where
/// `lin(coord) = Σ coord[axis] * stride_y[axis]`.
///
/// `shape` and `stride_y` are host-side arrays; strides are signed.
/// `numel` must equal the product of `shape`, and `rank` may be at
/// most 8. Returns an `i32` status code (`0` = success; see the
/// crate-level docs).
///
/// # Safety
/// Follows the same write-only contract as the contiguous
/// [`baracuda_kernels_fill_f32_run`].
pub fn baracuda_kernels_fill_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
value: f32,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `fill_f32`.
///
/// Takes only `numel` and `rank`. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_fill_f32_strided_can_implement(numel: i64, rank: i32) -> i32;
/// Strided fill, f64 dtype: writes `value` to
/// `y[Σ coord[axis] * stride_y[axis]]` for every `coord` iterated
/// row-major over `shape[0..rank]`.
///
/// `shape` and `stride_y` are host-side arrays; strides are signed.
/// `numel` must equal the product of `shape`, and `rank` may be at
/// most 8. Returns an `i32` status code (`0` = success; see the
/// crate-level docs).
pub fn baracuda_kernels_fill_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
value: f64,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `fill_f64`.
///
/// Takes only `numel` and `rank`. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_fill_f64_strided_can_implement(numel: i64, rank: i32) -> i32;
/// Strided fill, i32 dtype: writes `value` to
/// `y[Σ coord[axis] * stride_y[axis]]` for every `coord` iterated
/// row-major over `shape[0..rank]`.
///
/// `shape` and `stride_y` are host-side arrays; strides are signed.
/// `numel` must equal the product of `shape`, and `rank` may be at
/// most 8. Returns an `i32` status code (`0` = success; see the
/// crate-level docs).
pub fn baracuda_kernels_fill_i32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
value: i32,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `fill_i32`.
///
/// Takes only `numel` and `rank`. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_fill_i32_strided_can_implement(numel: i64, rank: i32) -> i32;
/// Strided fill, i64 dtype: writes `value` to
/// `y[Σ coord[axis] * stride_y[axis]]` for every `coord` iterated
/// row-major over `shape[0..rank]`.
///
/// `shape` and `stride_y` are host-side arrays; strides are signed.
/// `numel` must equal the product of `shape`, and `rank` may be at
/// most 8. Returns an `i32` status code (`0` = success; see the
/// crate-level docs).
pub fn baracuda_kernels_fill_i64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
value: i64,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `fill_i64`.
///
/// Takes only `numel` and `rank`. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_fill_i64_strided_can_implement(numel: i64, rank: i32) -> i32;
/// Strided fill, u8 dtype: writes `value` to
/// `y[Σ coord[axis] * stride_y[axis]]` for every `coord` iterated
/// row-major over `shape[0..rank]`.
///
/// `shape` and `stride_y` are host-side arrays; strides are signed.
/// `numel` must equal the product of `shape`, and `rank` may be at
/// most 8. Returns an `i32` status code (`0` = success; see the
/// crate-level docs).
pub fn baracuda_kernels_fill_u8_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
value: u8,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `fill_u8`.
///
/// Takes only `numel` and `rank`. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_fill_u8_strided_can_implement(numel: i64, rank: i32) -> i32;
/// Strided fill, i8 dtype: writes `value` to
/// `y[Σ coord[axis] * stride_y[axis]]` for every `coord` iterated
/// row-major over `shape[0..rank]`.
///
/// `shape` and `stride_y` are host-side arrays; strides are signed.
/// `numel` must equal the product of `shape`, and `rank` may be at
/// most 8. Returns an `i32` status code (`0` = success; see the
/// crate-level docs).
pub fn baracuda_kernels_fill_i8_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
value: i8,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `fill_i8`.
///
/// Takes only `numel` and `rank`. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_fill_i8_strided_can_implement(numel: i64, rank: i32) -> i32;
/// Strided fill, u32 dtype: writes `value` to
/// `y[Σ coord[axis] * stride_y[axis]]` for every `coord` iterated
/// row-major over `shape[0..rank]`.
///
/// `shape` and `stride_y` are host-side arrays; strides are signed.
/// `numel` must equal the product of `shape`, and `rank` may be at
/// most 8. Returns an `i32` status code (`0` = success; see the
/// crate-level docs).
pub fn baracuda_kernels_fill_u32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
value: u32,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `fill_u32`.
///
/// Takes only `numel` and `rank`. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_fill_u32_strided_can_implement(numel: i64, rank: i32) -> i32;
/// Strided fill, i16 dtype: writes `value` to
/// `y[Σ coord[axis] * stride_y[axis]]` for every `coord` iterated
/// row-major over `shape[0..rank]`.
///
/// `shape` and `stride_y` are host-side arrays; strides are signed.
/// `numel` must equal the product of `shape`, and `rank` may be at
/// most 8. Returns an `i32` status code (`0` = success; see the
/// crate-level docs).
pub fn baracuda_kernels_fill_i16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
value: i16,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `fill_i16`.
///
/// Takes only `numel` and `rank`. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_fill_i16_strided_can_implement(numel: i64, rank: i32) -> i32;
/// Strided fill, FP8 E4M3 dtype: writes `value` (transported as a raw
/// `u8`) to `y[Σ coord[axis] * stride_y[axis]]` for every `coord`
/// iterated row-major over `shape[0..rank]`.
///
/// `shape` and `stride_y` are host-side arrays; strides are signed.
/// `numel` must equal the product of `shape`, and `rank` may be at
/// most 8. Returns an `i32` status code (`0` = success; see the
/// crate-level docs).
pub fn baracuda_kernels_fill_fp8e4m3_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
value: u8,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `fill_fp8e4m3`.
///
/// Takes only `numel` and `rank`. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_fill_fp8e4m3_strided_can_implement(numel: i64, rank: i32) -> i32;
/// Strided fill, f16. `value_bits` is the raw 16-bit pattern of an `f16` value.
///
/// Writes to `y[Σ coord[axis] * stride_y[axis]]` for every `coord`
/// iterated row-major over `shape[0..rank]`. `shape` and `stride_y`
/// are host-side arrays; `numel` must equal the product of `shape`,
/// and `rank` may be at most 8. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_fill_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
value_bits: u16,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `fill_f16`.
///
/// Takes only `numel` and `rank`. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_fill_f16_strided_can_implement(numel: i64, rank: i32) -> i32;
/// Strided fill, bf16. `value_bits` is the raw 16-bit pattern of a `bf16` value.
///
/// Writes to `y[Σ coord[axis] * stride_y[axis]]` for every `coord`
/// iterated row-major over `shape[0..rank]`. `shape` and `stride_y`
/// are host-side arrays; `numel` must equal the product of `shape`,
/// and `rank` may be at most 8. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_fill_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_y: *const i64,
value_bits: u16,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `fill_bf16`.
///
/// Takes only `numel` and `rank`. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_fill_bf16_strided_can_implement(numel: i64, rank: i32) -> i32;
// ----- Affine ----------------------------------------------------------
//
// `y[i] = a * x[i] + b`. f16 / bf16 receive `a` / `b` as `f32` and
// compute through f32 internally (matches the elementwise family's
// f32-accumulator contract). The other dtypes receive `a` / `b` in
// the kernel's element type directly.
/// Affine `y = a*x + b`, f32 dtype.
///
/// `a` and `b` are passed by value as `f32`. Returns an `i32` status
/// code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `x` and `y` must each point to at least `numel` `f32`s of device
/// memory. Aliasing `y` with `x` is safe.
pub fn baracuda_kernels_affine_f32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
a: f32,
b: f32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `affine_f32`. Host-side only.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_affine_f32_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_affine_f32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Affine `y = a*x + b`, f64 dtype.
///
/// `a` and `b` are passed by value as `f64`. Returns an `i32` status
/// code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same contract as the f32 variant; storage is f64.
pub fn baracuda_kernels_affine_f64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
a: f64,
b: f64,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `affine_f64`. Host-side only.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_affine_f64_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_affine_f64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Affine `y = a*x + b`, i32 dtype.
///
/// `a` and `b` are passed by value as `i32`. Returns an `i32` status
/// code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same contract as the f32 variant; storage is i32.
pub fn baracuda_kernels_affine_i32_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
a: i32,
b: i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `affine_i32`. Host-side only.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_affine_i32_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_affine_i32_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Affine `y = a*x + b`, i64 dtype.
///
/// `a` and `b` are passed by value as `i64`. Returns an `i32` status
/// code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same contract as the f32 variant; storage is i64.
pub fn baracuda_kernels_affine_i64_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
a: i64,
b: i64,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `affine_i64`. Host-side only.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_affine_i64_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_affine_i64_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Affine `y = a*x + b`, u8 dtype.
///
/// `a` and `b` are passed by value as `u8`. Returns an `i32` status
/// code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same contract as the f32 variant; storage is u8.
pub fn baracuda_kernels_affine_u8_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
a: u8,
b: u8,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `affine_u8`. Host-side only.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_affine_u8_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_affine_u8_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Affine `y = a*x + b`, i8 dtype.
///
/// `a` and `b` are passed by value as `i8`. Returns an `i32` status
/// code (`0` = success; see the crate-level docs).
///
/// # Safety
/// Same contract as the f32 variant; storage is i8.
pub fn baracuda_kernels_affine_i8_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
a: i8,
b: i8,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `affine_i8`. Host-side only.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_affine_i8_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_affine_i8_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Affine `y = a*x + b`, f16 storage / f32 compute. `a` / `b`
/// arrive as `f32`.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// docs).
///
/// # Safety
/// `x` and `y` must each point to at least `numel * 2` bytes of
/// device memory holding `__half` values. Aliasing is safe.
pub fn baracuda_kernels_affine_f16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
a: f32,
b: f32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `affine_f16`. Host-side only.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_affine_f16_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_affine_f16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Affine `y = a*x + b`, bf16 storage / f32 compute. `a` / `b`
/// arrive as `f32`.
///
/// Returns an `i32` status code (`0` = success; see the crate-level
/// docs).
///
/// # Safety
/// `x` and `y` must each point to at least `numel * 2` bytes of
/// device memory holding `__nv_bfloat16` values. Aliasing is safe.
pub fn baracuda_kernels_affine_bf16_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
a: f32,
b: f32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `affine_bf16`. Host-side only.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_affine_bf16_run`] and returns an `i32` status
/// code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_affine_bf16_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// ------------------------------------------------------------------------
// Strided sibling — Phase 14.1.
//
// One thread per output element; thread decomposes its output
// linear index into a multi-coord and dots with the per-axis
// input / output strides (signed i64) to derive source / dest
// element offsets, then computes `y = a*x + b` at the same
// precision as the contig sibling.
//
// Common ABI:
// numel : total output element count (product of `shape`).
// rank : tensor rank, in `[0, 8]`.
// shape : `*const i32` of `rank` entries; logical extents.
// stride_x : `*const i64` of `rank` entries; SIGNED input strides
// (negative = flipped axis, zero = broadcast).
// stride_y : `*const i64` of `rank` entries; SIGNED output strides.
// x / y : device pointers to element-typed storage.
// a / b : scalar multiplier / bias. Same dtype as the contig
// sibling (f32 for f16/bf16 paths; T otherwise).
// workspace / workspace_bytes : unused (kept for ABI symmetry).
// stream : CUDA stream handle (cast to `cudaStream_t`).
//
// Status codes mirror the contig variant:
// 0 success, 2 invalid problem, 5 internal launch error.
/// Strided affine `y = a*x + b`, f32 dtype.
///
/// `shape`, `stride_x`, and `stride_y` each hold `rank` entries, with
/// `rank` in `[0, 8]`; strides are signed (negative = flipped axis,
/// zero = broadcast). `workspace` / `workspace_bytes` are unused.
/// Returns an `i32` status code: `0` success, `2` invalid problem,
/// `5` internal launch error.
///
/// # Safety
/// `x` / `y` must hold at least `numel` f32 elements, indexable via
/// the per-element offsets implied by (`shape`, `stride_x` / `stride_y`).
/// Aliasing `y` with `x` is safe so long as no two output indices
/// resolve to the same `x` element (in-place over a strided view).
///
/// NOTE(review): a zero (broadcast) `stride_x` would let `x` hold fewer
/// than `numel` elements — confirm the intended minimum size.
pub fn baracuda_kernels_affine_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
a: f32,
b: f32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `affine_f32`.
///
/// Takes the arguments of [`baracuda_kernels_affine_f32_strided_run`]
/// minus `workspace`, `workspace_bytes`, and `stream`. Returns an
/// `i32` status code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_affine_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
a: f32,
b: f32,
) -> i32;
/// Strided affine `y = a*x + b`, f64 dtype.
///
/// `shape`, `stride_x`, and `stride_y` each hold `rank` entries, with
/// `rank` in `[0, 8]`; strides are signed. Returns an `i32` status
/// code: `0` success, `2` invalid problem, `5` internal launch error.
///
/// # Safety
/// Same as the f32-strided variant; storage is f64.
pub fn baracuda_kernels_affine_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
a: f64,
b: f64,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `affine_f64`.
///
/// Takes the arguments of [`baracuda_kernels_affine_f64_strided_run`]
/// minus `workspace`, `workspace_bytes`, and `stream`. Returns an
/// `i32` status code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_affine_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
a: f64,
b: f64,
) -> i32;
/// Strided affine `y = a*x + b`, i32 dtype.
///
/// `shape`, `stride_x`, and `stride_y` each hold `rank` entries, with
/// `rank` in `[0, 8]`; strides are signed. Returns an `i32` status
/// code: `0` success, `2` invalid problem, `5` internal launch error.
///
/// # Safety
/// Same as the f32-strided variant; storage is i32.
pub fn baracuda_kernels_affine_i32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
a: i32,
b: i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `affine_i32`.
///
/// Takes the arguments of [`baracuda_kernels_affine_i32_strided_run`]
/// minus `workspace`, `workspace_bytes`, and `stream`. Returns an
/// `i32` status code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_affine_i32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
a: i32,
b: i32,
) -> i32;
/// Strided affine `y = a*x + b`, i64 dtype.
///
/// `shape`, `stride_x`, and `stride_y` each hold `rank` entries, with
/// `rank` in `[0, 8]`; strides are signed. Returns an `i32` status
/// code: `0` success, `2` invalid problem, `5` internal launch error.
///
/// # Safety
/// Same as the f32-strided variant; storage is i64.
pub fn baracuda_kernels_affine_i64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
a: i64,
b: i64,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `affine_i64`.
///
/// Takes the arguments of [`baracuda_kernels_affine_i64_strided_run`]
/// minus `workspace`, `workspace_bytes`, and `stream`. Returns an
/// `i32` status code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_affine_i64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
a: i64,
b: i64,
) -> i32;
/// Strided affine `y = a*x + b`, u8 dtype.
///
/// `shape`, `stride_x`, and `stride_y` each hold `rank` entries, with
/// `rank` in `[0, 8]`; strides are signed. Returns an `i32` status
/// code: `0` success, `2` invalid problem, `5` internal launch error.
///
/// # Safety
/// Same as the f32-strided variant; storage is u8. Wraps on overflow
/// (mod 256), matching the C `uint8_t` `*` / `+` operator semantics.
pub fn baracuda_kernels_affine_u8_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
a: u8,
b: u8,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `affine_u8`.
///
/// Takes the arguments of [`baracuda_kernels_affine_u8_strided_run`]
/// minus `workspace`, `workspace_bytes`, and `stream`. Returns an
/// `i32` status code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_affine_u8_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
a: u8,
b: u8,
) -> i32;
/// Strided affine `y = a*x + b`, f16 storage / f32 compute. `a` /
/// `b` arrive as `f32`.
///
/// `shape`, `stride_x`, and `stride_y` each hold `rank` entries, with
/// `rank` in `[0, 8]`; strides are signed. Returns an `i32` status
/// code: `0` success, `2` invalid problem, `5` internal launch error.
///
/// # Safety
/// `x` / `y` must hold at least `numel` `__half` elements; index
/// pattern follows (`shape`, `stride_x` / `stride_y`).
pub fn baracuda_kernels_affine_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
a: f32,
b: f32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `affine_f16`.
///
/// Takes the arguments of [`baracuda_kernels_affine_f16_strided_run`]
/// minus `workspace`, `workspace_bytes`, and `stream`. Returns an
/// `i32` status code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_affine_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
a: f32,
b: f32,
) -> i32;
/// Strided affine `y = a*x + b`, bf16 storage / f32 compute. `a` /
/// `b` arrive as `f32`.
///
/// `shape`, `stride_x`, and `stride_y` each hold `rank` entries, with
/// `rank` in `[0, 8]`; strides are signed. Returns an `i32` status
/// code: `0` success, `2` invalid problem, `5` internal launch error.
///
/// # Safety
/// `x` / `y` must hold at least `numel` `__nv_bfloat16` elements;
/// index pattern follows (`shape`, `stride_x` / `stride_y`).
pub fn baracuda_kernels_affine_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
a: f32,
b: f32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for strided `affine_bf16`.
///
/// Takes the arguments of [`baracuda_kernels_affine_bf16_strided_run`]
/// minus `workspace`, `workspace_bytes`, and `stream`. Returns an
/// `i32` status code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_affine_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
a: f32,
b: f32,
) -> i32;
}
// ============================================================================
// GGUF block-format quantization — Phase 8 Milestone 8.4 (Category P)
// ============================================================================
//
// Vendored from llama.cpp via fuel-cuda-kernels. See
// `kernels/include/baracuda_gguf.cuh` for lineage notes.
//
// Symbols are split into two op families:
// * `baracuda_kernels_dequantize_<qtype>_run` — unpack a GGUF-packed
// weight matrix into a dense f32 tensor.
// * `baracuda_kernels_mmvq_<qtype>_run` — fused dequant + matmul-vec
// (FP-activation MMVQ): `out[r] = Σ_c W_q[r, c] · y[c]`. Single FP
// activation vector in, FP output vector out.
//
// Pointer ABI (both families):
// - `x` / `vx` is the GGUF-packed weight buffer (raw bytes; element
// stride = the block size of the qtype). For dequant: pointer is
// to a flat byte buffer storing `ceil(numel / block_size)` blocks.
// For MMVQ: pointer is to a `[nrows × packed_cols_bytes]` matrix
// laid out row-major (one block-row per qtype row).
// - Dequant output `y` is `float*` device memory, length `numel`.
// - MMVQ activation `y` is `float*` of length `ncols`. MMVQ output
// `dst` is `float*` of length `nrows`.
//
// `numel` (dequant) and `ncols` (MMVQ) MUST be divisible by the qtype's
// block size: 32 for the type-0/1 qtypes (Q4_0 / Q4_1 / Q5_0 / Q5_1 /
// Q8_0); 256 for the k-quants (Q2_K / Q3_K / Q4_K / Q5_K / Q6_K / Q8_K).
//
// Q8_K MMVQ — added by baracuda in Phase 11.4 as a bespoke kernel
// (not vendored). Upstream llama.cpp / Fuel ship only the dequant
// kernel and treat Q8_K as a CPU-side intermediate; baracuda exposes
// a fused MMVQ to avoid the 2× memory traffic of dequant-then-GEMV
// on the inference decode step.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---- Dequantize: type-0/1 (32-element blocks) ----
/// GGUF `Q4_0` block-format dequantize → f32.
///
/// `x` is the GGUF-packed weight buffer and `y` receives `numel` f32
/// values. `numel` must be a multiple of 32. Returns an `i32` status
/// code (`0` = success; see the crate-level docs).
///
/// # Safety
/// `x` and `y` must be device-resident and `stream` must be valid.
pub fn baracuda_kernels_dequantize_q4_0_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for the GGUF `Q4_0` dequantize.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_dequantize_q4_0_run`] and returns an `i32`
/// status code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_dequantize_q4_0_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// GGUF `Q4_1` dequantize → f32.
///
/// `x` is the GGUF-packed weight buffer and `y` receives `numel` f32
/// values. `numel` must be a multiple of 32. Returns an `i32` status
/// code (`0` = success; see the crate-level docs).
///
/// # Safety
/// As `Q4_0`: `x` and `y` must be device-resident and `stream` must
/// be valid.
pub fn baracuda_kernels_dequantize_q4_1_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for the GGUF `Q4_1` dequantize.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_dequantize_q4_1_run`] and returns an `i32`
/// status code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_dequantize_q4_1_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// GGUF `Q5_0` dequantize → f32.
///
/// `x` is the GGUF-packed weight buffer and `y` receives `numel` f32
/// values. `numel` must be a multiple of 32. Returns an `i32` status
/// code (`0` = success; see the crate-level docs).
///
/// # Safety
/// As `Q4_0`: `x` and `y` must be device-resident and `stream` must
/// be valid.
pub fn baracuda_kernels_dequantize_q5_0_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for the GGUF `Q5_0` dequantize.
///
/// Takes the `numel`, `x`, and `y` of
/// [`baracuda_kernels_dequantize_q5_0_run`] and returns an `i32`
/// status code (`0` = success; see the crate-level docs).
pub fn baracuda_kernels_dequantize_q5_0_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// GGUF `Q5_1` dequantize → f32.
///
/// `x` is the GGUF-packed weight buffer and `y` receives `numel` f32
/// values. `numel` must be a multiple of 32. Returns an `i32` status
/// code (`0` = success; see the crate-level docs).
///
/// # Safety
/// As `Q4_0`: `x` and `y` must be device-resident and `stream` must
/// be valid.
pub fn baracuda_kernels_dequantize_q5_1_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_dequantize_q5_1_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x` and `y` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_dequantize_q5_1_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Dequantizes GGUF `Q8_0` blocks from `x` into f32 values at `y`.
///
/// `numel` must be a multiple of 32. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` and `y` must be device-resident and `stream` a valid stream;
/// a non-null `workspace` must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_dequantize_q8_0_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_dequantize_q8_0_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x` and `y` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_dequantize_q8_0_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// ---- Dequantize: k-quants (256-element blocks) ----
/// Dequantizes GGUF `Q2_K` blocks from `x` into f32 values at `y`.
///
/// `numel` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` and `y` must be device-resident and `stream` a valid stream;
/// a non-null `workspace` must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_dequantize_q2_K_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_dequantize_q2_K_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x` and `y` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_dequantize_q2_K_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Dequantizes GGUF `Q3_K` blocks from `x` into f32 values at `y`.
///
/// `numel` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` and `y` must be device-resident and `stream` a valid stream;
/// a non-null `workspace` must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_dequantize_q3_K_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_dequantize_q3_K_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x` and `y` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_dequantize_q3_K_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Dequantizes GGUF `Q4_K` blocks from `x` into f32 values at `y`.
///
/// `numel` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` and `y` must be device-resident and `stream` a valid stream;
/// a non-null `workspace` must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_dequantize_q4_K_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_dequantize_q4_K_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x` and `y` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_dequantize_q4_K_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Dequantizes GGUF `Q5_K` blocks from `x` into f32 values at `y`.
///
/// `numel` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` and `y` must be device-resident and `stream` a valid stream;
/// a non-null `workspace` must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_dequantize_q5_K_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_dequantize_q5_K_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x` and `y` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_dequantize_q5_K_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Dequantizes GGUF `Q6_K` blocks from `x` into f32 values at `y`.
///
/// `numel` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` and `y` must be device-resident and `stream` a valid stream;
/// a non-null `workspace` must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_dequantize_q6_K_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_dequantize_q6_K_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x` and `y` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_dequantize_q6_K_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Dequantizes GGUF `Q8_K` blocks from `x` into f32 values at `y`.
///
/// `numel` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` and `y` must be device-resident and `stream` a valid stream;
/// a non-null `workspace` must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_dequantize_q8_K_run(
numel: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_dequantize_q8_K_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x` and `y` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_dequantize_q8_K_can_implement(
numel: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
// ---- MMVQ: type-0/1 (32-element blocks) ----
/// GGUF `Q4_0` MMVQ: FP-activation matrix-vector multiply.
///
/// `ncols` must be a multiple of 32. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` (packed), `y` (f32 activation) and `dst` (f32 output) must be
/// device-resident and `stream` a valid stream; a non-null `workspace`
/// must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_mmvq_q4_0_run(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q4_0_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q4_0_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q4_1` MMVQ: FP-activation matrix-vector multiply.
///
/// `ncols` must be a multiple of 32. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` (packed), `y` (f32 activation) and `dst` (f32 output) must be
/// device-resident and `stream` a valid stream; a non-null `workspace`
/// must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_mmvq_q4_1_run(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q4_1_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q4_1_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q5_0` MMVQ: FP-activation matrix-vector multiply.
///
/// `ncols` must be a multiple of 32. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` (packed), `y` (f32 activation) and `dst` (f32 output) must be
/// device-resident and `stream` a valid stream; a non-null `workspace`
/// must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_mmvq_q5_0_run(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q5_0_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q5_0_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q5_1` MMVQ: FP-activation matrix-vector multiply.
///
/// `ncols` must be a multiple of 32. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` (packed), `y` (f32 activation) and `dst` (f32 output) must be
/// device-resident and `stream` a valid stream; a non-null `workspace`
/// must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_mmvq_q5_1_run(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q5_1_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q5_1_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q8_0` MMVQ: FP-activation matrix-vector multiply.
///
/// `ncols` must be a multiple of 32. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` (packed), `y` (f32 activation) and `dst` (f32 output) must be
/// device-resident and `stream` a valid stream; a non-null `workspace`
/// must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_mmvq_q8_0_run(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q8_0_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q8_0_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
// ---- MMVQ: k-quants (256-element blocks) ----
//
// `Q8_K` MMVQ is exposed (Phase 11.4 — bespoke, not vendored). Upstream
// llama.cpp / Fuel ship only the Q8_K dequant kernel; baracuda adds a
// fused MMVQ to avoid the 2× memory traffic of dequant-then-GEMV.
/// GGUF `Q2_K` MMVQ: FP-activation matrix-vector multiply.
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` (packed), `y` (f32 activation) and `dst` (f32 output) must be
/// device-resident and `stream` a valid stream; a non-null `workspace`
/// must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_mmvq_q2_K_run(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q2_K_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q2_K_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q3_K` MMVQ: FP-activation matrix-vector multiply.
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` (packed), `y` (f32 activation) and `dst` (f32 output) must be
/// device-resident and `stream` a valid stream; a non-null `workspace`
/// must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_mmvq_q3_K_run(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q3_K_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q3_K_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q4_K` MMVQ: FP-activation matrix-vector multiply.
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` (packed), `y` (f32 activation) and `dst` (f32 output) must be
/// device-resident and `stream` a valid stream; a non-null `workspace`
/// must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_mmvq_q4_K_run(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q4_K_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q4_K_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q5_K` MMVQ: FP-activation matrix-vector multiply.
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` (packed), `y` (f32 activation) and `dst` (f32 output) must be
/// device-resident and `stream` a valid stream; a non-null `workspace`
/// must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_mmvq_q5_K_run(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q5_K_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q5_K_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q6_K` MMVQ: FP-activation matrix-vector multiply.
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` (packed), `y` (f32 activation) and `dst` (f32 output) must be
/// device-resident and `stream` a valid stream; a non-null `workspace`
/// must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_mmvq_q6_K_run(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q6_K_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q6_K_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q8_K` MMVQ: FP-activation matrix-vector multiply (Phase 11.4;
/// bespoke, not vendored from llama.cpp).
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// `x` (packed), `y` (f32 activation) and `dst` (f32 output) must be
/// device-resident and `stream` a valid stream; a non-null `workspace`
/// must cover `workspace_bytes` writable device bytes.
pub fn baracuda_kernels_mmvq_q8_K_run(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q8_K_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q8_K_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
// ----- Activation-strided + W-offset MMVQ — Phase 14.5 ----------------
//
// Two runtime params added vs. the contig MMVQ FFI, plus one scope note:
//
// * `w_start_byte_offset` (`i64`) — affine byte offset into W's
// allocation. Lets a single device buffer host multiple GGUF
// matrices; the launcher does the pointer arithmetic host-side
// and the kernel is unchanged on the W side. Zero cost at
// offset = 0.
// * `stride_y` (`i64`) — element stride along the activation's
// `ncols` axis. Signed.
// - `stride_y == 1` matches the contig kernel exactly.
// - `stride_y == 0` broadcasts the single-element activation
// across every column (the GQA "kv-head-axis-stride-0"
// degenerate case is still the rank-2 host-batched arm; at
// the kernel-level FFI this just shows up as the host
// launching one MMVQ call per Q-head with the same `y`
// pointer).
// - other values walk a strided view.
// * Activation rank is implicitly 1 at the kernel-level FFI; the
// host loops over higher dims (batch / sequence / Q-head) if
// needed. This keeps the kernel surface narrow — MMVQ is by
// construction matrix × vector.
//
// Status codes match the contig family: 0 success, 2 invalid problem,
// 5 internal launch error.
/// Activation-strided GGUF `Q4_0` MMVQ (f32 activation, f32 output).
///
/// `w_start_byte_offset` is a byte offset into the weight allocation;
/// `stride_y` is the signed element stride of `y` along `ncols`
/// (`1` matches the contiguous kernel). Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q4_0_run`; additionally `y[k * stride_y]`
/// must be a valid f32 read for every `k` in `0..ncols`.
pub fn baracuda_kernels_mmvq_q4_0_actstrided_run(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q4_0_actstrided_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q4_0_actstrided_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Activation-strided GGUF `Q4_1` MMVQ (f32 activation, f32 output).
///
/// `w_start_byte_offset` is a byte offset into the weight allocation;
/// `stride_y` is the signed element stride of `y` along `ncols`
/// (`1` matches the contiguous kernel). Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q4_1_run`; additionally `y[k * stride_y]`
/// must be a valid f32 read for every `k` in `0..ncols`.
pub fn baracuda_kernels_mmvq_q4_1_actstrided_run(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q4_1_actstrided_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q4_1_actstrided_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Activation-strided GGUF `Q5_0` MMVQ (f32 activation, f32 output).
///
/// `w_start_byte_offset` is a byte offset into the weight allocation;
/// `stride_y` is the signed element stride of `y` along `ncols`
/// (`1` matches the contiguous kernel). Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q5_0_run`; additionally `y[k * stride_y]`
/// must be a valid f32 read for every `k` in `0..ncols`.
pub fn baracuda_kernels_mmvq_q5_0_actstrided_run(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q5_0_actstrided_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q5_0_actstrided_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Activation-strided GGUF `Q5_1` MMVQ (f32 activation, f32 output).
///
/// `w_start_byte_offset` is a byte offset into the weight allocation;
/// `stride_y` is the signed element stride of `y` along `ncols`
/// (`1` matches the contiguous kernel). Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q5_1_run`; additionally `y[k * stride_y]`
/// must be a valid f32 read for every `k` in `0..ncols`.
pub fn baracuda_kernels_mmvq_q5_1_actstrided_run(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q5_1_actstrided_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q5_1_actstrided_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Activation-strided GGUF `Q8_0` MMVQ (f32 activation, f32 output).
///
/// `w_start_byte_offset` is a byte offset into the weight allocation;
/// `stride_y` is the signed element stride of `y` along `ncols`
/// (`1` matches the contiguous kernel). Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q8_0_run`; additionally `y[k * stride_y]`
/// must be a valid f32 read for every `k` in `0..ncols`.
pub fn baracuda_kernels_mmvq_q8_0_actstrided_run(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q8_0_actstrided_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q8_0_actstrided_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Activation-strided GGUF `Q2_K` MMVQ (f32 activation, f32 output).
///
/// `w_start_byte_offset` is a byte offset into the weight allocation;
/// `stride_y` is the signed element stride of `y` along `ncols`
/// (`1` matches the contiguous kernel). Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q2_K_run`; additionally `y[k * stride_y]`
/// must be a valid f32 read for every `k` in `0..ncols`.
pub fn baracuda_kernels_mmvq_q2_K_actstrided_run(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q2_K_actstrided_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q2_K_actstrided_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Activation-strided GGUF `Q3_K` MMVQ (f32 activation, f32 output).
///
/// `w_start_byte_offset` is a byte offset into the weight allocation;
/// `stride_y` is the signed element stride of `y` along `ncols`
/// (`1` matches the contiguous kernel). Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q3_K_run`; additionally `y[k * stride_y]`
/// must be a valid f32 read for every `k` in `0..ncols`.
pub fn baracuda_kernels_mmvq_q3_K_actstrided_run(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q3_K_actstrided_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q3_K_actstrided_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Activation-strided GGUF `Q4_K` MMVQ (f32 activation, f32 output).
///
/// `w_start_byte_offset` is a byte offset into the weight allocation;
/// `stride_y` is the signed element stride of `y` along `ncols`
/// (`1` matches the contiguous kernel). Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q4_K_run`; additionally `y[k * stride_y]`
/// must be a valid f32 read for every `k` in `0..ncols`.
pub fn baracuda_kernels_mmvq_q4_K_actstrided_run(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q4_K_actstrided_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q4_K_actstrided_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Activation-strided GGUF `Q5_K` MMVQ (f32 activation, f32 output).
///
/// `w_start_byte_offset` is a byte offset into the weight allocation;
/// `stride_y` is the signed element stride of `y` along `ncols`
/// (`1` matches the contiguous kernel). Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q5_K_run`; additionally `y[k * stride_y]`
/// must be a valid f32 read for every `k` in `0..ncols`.
pub fn baracuda_kernels_mmvq_q5_K_actstrided_run(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q5_K_actstrided_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q5_K_actstrided_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Activation-strided GGUF `Q6_K` MMVQ (f32 activation, f32 output).
///
/// `w_start_byte_offset` is a byte offset into the weight allocation;
/// `stride_y` is the signed element stride of `y` along `ncols`
/// (`1` matches the contiguous kernel). Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q6_K_run`; additionally `y[k * stride_y]`
/// must be a valid f32 read for every `k` in `0..ncols`.
pub fn baracuda_kernels_mmvq_q6_K_actstrided_run(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q6_K_actstrided_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q6_K_actstrided_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Activation-strided GGUF `Q8_K` MMVQ (f32 activation, f32 output;
/// bespoke, Phase 11.4 + 14.5).
///
/// `w_start_byte_offset` is a byte offset into the weight allocation;
/// `stride_y` is the signed element stride of `y` along `ncols`
/// (`1` matches the contiguous kernel). Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q8_K_run`; additionally `y[k * stride_y]`
/// must be a valid f32 read for every `k` in `0..ncols`.
pub fn baracuda_kernels_mmvq_q8_K_actstrided_run(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q8_K_actstrided_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q8_K_actstrided_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
// ===== Phase 18.1 — f16 / bf16 activation MMVQ ============================
//
// 11 block formats × {f16, bf16} × {contig, actstrided} = 44 new `*_run`
// symbols, each paired with a `*_can_implement` sibling.
//
// Convention vs. the f32 baseline:
// * `y` and `dst` share the same dtype (PyTorch convention: f16 activation
// → f16 dst; bf16 → bf16). The dtype is encoded in the symbol suffix
// (`_f16_run` / `_bf16_run` / `_actstrided_f16_run` / `_actstrided_bf16_run`).
// * The existing un-suffixed symbol (`*_run`, `*_actstrided_run`) keeps
// f32 activation + f32 output, preserved for back-compat.
// * Internal accumulator stays f32 in every variant — the f16 / bf16 cast
// happens at the load (activation read) and store (dst write) sites.
//
// Status codes match the f32 family: 0 success, 2 invalid problem,
// 5 internal launch error.
/// GGUF `Q4_0` MMVQ with f16 activation (`y`) and f16 output (`dst`).
///
/// `ncols` must be a multiple of 32. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q4_0_run`, except `y` and `dst` are
/// device-resident `__half` buffers.
pub fn baracuda_kernels_mmvq_q4_0_f16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q4_0_f16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q4_0_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q4_0` MMVQ with bf16 activation (`y`) and bf16 output (`dst`).
///
/// `ncols` must be a multiple of 32. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q4_0_run`, except `y` and `dst` are
/// device-resident `__nv_bfloat16` buffers.
pub fn baracuda_kernels_mmvq_q4_0_bf16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q4_0_bf16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q4_0_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q4_1` MMVQ with f16 activation (`y`) and f16 output (`dst`).
///
/// `ncols` must be a multiple of 32. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q4_1_run`, except `y` and `dst` are
/// device-resident `__half` buffers.
pub fn baracuda_kernels_mmvq_q4_1_f16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q4_1_f16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q4_1_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q4_1` MMVQ with bf16 activation (`y`) and bf16 output (`dst`).
///
/// `ncols` must be a multiple of 32. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q4_1_run`, except `y` and `dst` are
/// device-resident `__nv_bfloat16` buffers.
pub fn baracuda_kernels_mmvq_q4_1_bf16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q4_1_bf16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q4_1_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q5_0` MMVQ with f16 activation (`y`) and f16 output (`dst`).
///
/// `ncols` must be a multiple of 32. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q5_0_run`, except `y` and `dst` are
/// device-resident `__half` buffers.
pub fn baracuda_kernels_mmvq_q5_0_f16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q5_0_f16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q5_0_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q5_0` MMVQ with bf16 activation (`y`) and bf16 output (`dst`).
///
/// `ncols` must be a multiple of 32. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q5_0_run`, except `y` and `dst` are
/// device-resident `__nv_bfloat16` buffers.
pub fn baracuda_kernels_mmvq_q5_0_bf16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q5_0_bf16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q5_0_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q5_1` MMVQ with f16 activation (`y`) and f16 output (`dst`).
///
/// `ncols` must be a multiple of 32. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q5_1_run`, except `y` and `dst` are
/// device-resident `__half` buffers.
pub fn baracuda_kernels_mmvq_q5_1_f16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q5_1_f16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q5_1_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q5_1` MMVQ with bf16 activation (`y`) and bf16 output (`dst`).
///
/// `ncols` must be a multiple of 32. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q5_1_run`, except `y` and `dst` are
/// device-resident `__nv_bfloat16` buffers.
pub fn baracuda_kernels_mmvq_q5_1_bf16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q5_1_bf16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q5_1_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q8_0` MMVQ with f16 activation (`y`) and f16 output (`dst`).
///
/// `ncols` must be a multiple of 32. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q8_0_run`, except `y` and `dst` are
/// device-resident `__half` buffers.
pub fn baracuda_kernels_mmvq_q8_0_f16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q8_0_f16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q8_0_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q8_0` MMVQ with bf16 activation (`y`) and bf16 output (`dst`).
///
/// `ncols` must be a multiple of 32. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q8_0_run`, except `y` and `dst` are
/// device-resident `__nv_bfloat16` buffers.
pub fn baracuda_kernels_mmvq_q8_0_bf16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q8_0_bf16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q8_0_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q2_K` MMVQ with f16 activation (`y`) and f16 output (`dst`).
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q2_K_run`, except `y` and `dst` are
/// device-resident `__half` buffers.
pub fn baracuda_kernels_mmvq_q2_K_f16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q2_K_f16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q2_K_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q2_K` MMVQ with bf16 activation (`y`) and bf16 output (`dst`).
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q2_K_run`, except `y` and `dst` are
/// device-resident `__nv_bfloat16` buffers.
pub fn baracuda_kernels_mmvq_q2_K_bf16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q2_K_bf16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q2_K_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q3_K` MMVQ with f16 activation (`y`) and f16 output (`dst`).
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q3_K_run`, except `y` and `dst` are
/// device-resident `__half` buffers.
pub fn baracuda_kernels_mmvq_q3_K_f16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q3_K_f16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q3_K_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q3_K` MMVQ with bf16 activation (`y`) and bf16 output (`dst`).
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q3_K_run`, except `y` and `dst` are
/// device-resident `__nv_bfloat16` buffers.
pub fn baracuda_kernels_mmvq_q3_K_bf16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q3_K_bf16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q3_K_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q4_K` MMVQ with f16 activation (`y`) and f16 output (`dst`).
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q4_K_run`, except `y` and `dst` are
/// device-resident `__half` buffers.
pub fn baracuda_kernels_mmvq_q4_K_f16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q4_K_f16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q4_K_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q4_K` MMVQ with bf16 activation (`y`) and bf16 output (`dst`).
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q4_K_run`, except `y` and `dst` are
/// device-resident `__nv_bfloat16` buffers.
pub fn baracuda_kernels_mmvq_q4_K_bf16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q4_K_bf16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q4_K_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q5_K` MMVQ with f16 activation (`y`) and f16 output (`dst`).
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q5_K_run`, except `y` and `dst` are
/// device-resident `__half` buffers.
pub fn baracuda_kernels_mmvq_q5_K_f16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q5_K_f16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q5_K_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q5_K` MMVQ with bf16 activation (`y`) and bf16 output (`dst`).
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q5_K_run`, except `y` and `dst` are
/// device-resident `__nv_bfloat16` buffers.
pub fn baracuda_kernels_mmvq_q5_K_bf16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q5_K_bf16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q5_K_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q6_K` MMVQ with f16 activation (`y`) and f16 output (`dst`).
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q6_K_run`, except `y` and `dst` are
/// device-resident `__half` buffers.
pub fn baracuda_kernels_mmvq_q6_K_f16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q6_K_f16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q6_K_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q6_K` MMVQ with bf16 activation (`y`) and bf16 output (`dst`).
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q6_K_run`, except `y` and `dst` are
/// device-resident `__nv_bfloat16` buffers.
pub fn baracuda_kernels_mmvq_q6_K_bf16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q6_K_bf16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q6_K_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q8_K` MMVQ with f16 activation (`y`) and f16 output (`dst`)
/// (bespoke; Phase 11.4 + 18.1).
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q8_K_run`, except `y` and `dst` are
/// device-resident `__half` buffers.
pub fn baracuda_kernels_mmvq_q8_K_f16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q8_K_f16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q8_K_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// GGUF `Q8_K` MMVQ with bf16 activation (`y`) and bf16 output (`dst`)
/// (bespoke; Phase 11.4 + 18.1).
///
/// `ncols` must be a multiple of 256. Returns an `i32` status code
/// (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q8_K_run`, except `y` and `dst` are
/// device-resident `__nv_bfloat16` buffers.
pub fn baracuda_kernels_mmvq_q8_K_bf16_run(
ncols: i32, nrows: i32,
x: *const c_void, y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q8_K_bf16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q8_K_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
y: *const c_void,
dst: *const c_void,
) -> i32;
// ---- Phase 18.1 strided f16 / bf16 siblings ------------------------------
/// Activation-strided GGUF `Q4_0` MMVQ with f16 activation and f16 output.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q4_0_actstrided_run` (the f32 strided sibling),
/// with `y` and `dst` holding f16 values instead of f32.
pub fn baracuda_kernels_mmvq_q4_0_actstrided_f16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q4_0_actstrided_f16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q4_0_actstrided_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Activation-strided GGUF `Q4_0` MMVQ with bf16 activation and bf16 output.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q4_0_actstrided_run` (the f32 strided sibling),
/// with `y` and `dst` holding bf16 values instead of f32.
pub fn baracuda_kernels_mmvq_q4_0_actstrided_bf16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q4_0_actstrided_bf16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q4_0_actstrided_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Activation-strided GGUF `Q4_1` MMVQ with f16 activation and f16 output.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q4_1_actstrided_run` (the f32 strided sibling),
/// with `y` and `dst` holding f16 values instead of f32.
pub fn baracuda_kernels_mmvq_q4_1_actstrided_f16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q4_1_actstrided_f16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q4_1_actstrided_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Activation-strided GGUF `Q4_1` MMVQ with bf16 activation and bf16 output.
///
/// Returns an `i32` status code (`0` = success; see the crate-level docs).
///
/// # Safety
///
/// As `baracuda_kernels_mmvq_q4_1_actstrided_run` (the f32 strided sibling),
/// with `y` and `dst` holding bf16 values instead of f32.
pub fn baracuda_kernels_mmvq_q4_1_actstrided_bf16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// Feasibility probe for `baracuda_kernels_mmvq_q4_1_actstrided_bf16_run`.
///
/// Same problem arguments minus workspace and stream; returns an `i32`
/// status code (`0` = success; see the crate-level docs).
/// NOTE(review): presumably validates without launching a kernel; confirm.
///
/// # Safety
///
/// `x`, `y` and `dst` must be valid device addresses, per the crate-level docs.
pub fn baracuda_kernels_mmvq_q4_1_actstrided_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q5_0, f16.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q4_0 strided f16.
pub fn baracuda_kernels_mmvq_q5_0_actstrided_f16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q5_0_actstrided_f16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q5_0_actstrided_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q5_0, bf16.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q4_0 strided bf16.
pub fn baracuda_kernels_mmvq_q5_0_actstrided_bf16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q5_0_actstrided_bf16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q5_0_actstrided_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q5_1, f16.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q4_0 strided f16.
pub fn baracuda_kernels_mmvq_q5_1_actstrided_f16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q5_1_actstrided_f16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q5_1_actstrided_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q5_1, bf16.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q4_0 strided bf16.
pub fn baracuda_kernels_mmvq_q5_1_actstrided_bf16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q5_1_actstrided_bf16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q5_1_actstrided_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q8_0, f16.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q4_0 strided f16.
pub fn baracuda_kernels_mmvq_q8_0_actstrided_f16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q8_0_actstrided_f16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q8_0_actstrided_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q8_0, bf16.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q4_0 strided bf16.
pub fn baracuda_kernels_mmvq_q8_0_actstrided_bf16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q8_0_actstrided_bf16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q8_0_actstrided_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q2_K, f16.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q4_0 strided f16; additionally `ncols` must be a
/// multiple of 256.
pub fn baracuda_kernels_mmvq_q2_K_actstrided_f16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q2_K_actstrided_f16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q2_K_actstrided_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q2_K, bf16.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q2_K strided f16.
pub fn baracuda_kernels_mmvq_q2_K_actstrided_bf16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q2_K_actstrided_bf16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q2_K_actstrided_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q3_K, f16.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q2_K strided f16.
pub fn baracuda_kernels_mmvq_q3_K_actstrided_f16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q3_K_actstrided_f16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q3_K_actstrided_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q3_K, bf16.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q2_K strided bf16.
pub fn baracuda_kernels_mmvq_q3_K_actstrided_bf16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q3_K_actstrided_bf16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q3_K_actstrided_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q4_K, f16.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q2_K strided f16.
pub fn baracuda_kernels_mmvq_q4_K_actstrided_f16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q4_K_actstrided_f16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q4_K_actstrided_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q4_K, bf16.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q2_K strided bf16.
pub fn baracuda_kernels_mmvq_q4_K_actstrided_bf16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q4_K_actstrided_bf16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q4_K_actstrided_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q5_K, f16.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q2_K strided f16.
pub fn baracuda_kernels_mmvq_q5_K_actstrided_f16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q5_K_actstrided_f16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q5_K_actstrided_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q5_K, bf16.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q2_K strided bf16.
pub fn baracuda_kernels_mmvq_q5_K_actstrided_bf16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q5_K_actstrided_bf16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q5_K_actstrided_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q6_K, f16.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q2_K strided f16.
pub fn baracuda_kernels_mmvq_q6_K_actstrided_f16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q6_K_actstrided_f16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q6_K_actstrided_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q6_K, bf16.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q2_K strided bf16.
pub fn baracuda_kernels_mmvq_q6_K_actstrided_bf16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q6_K_actstrided_bf16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q6_K_actstrided_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q8_K, f16 (bespoke).
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q2_K strided f16.
pub fn baracuda_kernels_mmvq_q8_K_actstrided_f16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q8_K_actstrided_f16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q8_K_actstrided_f16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
/// Strided MMVQ — Q8_K, bf16 (bespoke).
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as Q2_K strided bf16.
pub fn baracuda_kernels_mmvq_q8_K_actstrided_bf16_run(
ncols: i32, nrows: i32, x: *const c_void,
w_start_byte_offset: i64, stride_y: i64,
y: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_q8_K_actstrided_bf16_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_q8_K_actstrided_bf16_can_implement(
ncols: i32,
nrows: i32,
x: *const c_void,
w_start_byte_offset: i64,
stride_y: i64,
y: *const c_void,
dst: *const c_void,
) -> i32;
// ===== Phase 33 — multi-M MMVQ via Q8_1 activation staging =================
//
// The multi-M MMVQ kernel reuses one weight load across `ncols_y ∈
// {1, 2, 4, 8}` activation vectors. The activations must first be
// staged into the Q8_1 block format (int8 quants + half scale + half
// sum, 36 bytes per 32-element block) by the `quantize_q8_1_*_run`
// family. The dot product is computed via `__dp4a` SIMD int8 MACs
// with an fp32 fixup at the block boundary.
//
// Scope: Q8_0 weights only in Phase 33 (the simplest dot — pure
// `__dp4a` × VDR=2 + scalar rescale). Remaining 9 block formats
// (Q4_0/Q4_1/Q5_0/Q5_1/Q2_K..Q6_K) land in a follow-up phase.
/// Q8_1 activation staging — f32 source.
///
/// Converts `ny × kx` f32 activations into the Q8_1 block format.
/// Output buffer size: `ny × ceil(kx / 32) × 36 bytes` (use
/// `baracuda_kernels_quantize_q8_1_workspace_bytes` to compute).
/// `kx` is rounded up to a multiple of 32 internally; out-of-range
/// columns are zero-padded.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// `src` and `dst_q8_1` must be device-resident; `stream` must be a
/// valid stream.
pub fn baracuda_kernels_quantize_q8_1_f32_run(
kx: i64, ny: i64,
src: *const c_void,
dst_q8_1: *mut c_void,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_quantize_q8_1_f32_run`.
///
/// Takes the same problem arguments (no stream; `dst_q8_1` is `*const`)
/// and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_quantize_q8_1_f32_can_implement(
kx: i64,
ny: i64,
src: *const c_void,
dst_q8_1: *const c_void,
) -> i32;
/// Q8_1 activation staging — f16 source.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as the f32 variant.
pub fn baracuda_kernels_quantize_q8_1_f16_run(
kx: i64, ny: i64,
src: *const c_void,
dst_q8_1: *mut c_void,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_quantize_q8_1_f16_run`.
///
/// Takes the same problem arguments (no stream; `dst_q8_1` is `*const`)
/// and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_quantize_q8_1_f16_can_implement(
kx: i64,
ny: i64,
src: *const c_void,
dst_q8_1: *const c_void,
) -> i32;
/// Q8_1 activation staging — bf16 source.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as the f32 variant.
pub fn baracuda_kernels_quantize_q8_1_bf16_run(
kx: i64, ny: i64,
src: *const c_void,
dst_q8_1: *mut c_void,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_quantize_q8_1_bf16_run`.
///
/// Takes the same problem arguments (no stream; `dst_q8_1` is `*const`)
/// and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_quantize_q8_1_bf16_can_implement(
kx: i64,
ny: i64,
src: *const c_void,
dst_q8_1: *const c_void,
) -> i32;
/// Returns the number of bytes needed to stage `ny × kx` activations
/// into Q8_1, i.e. `ny * ceil(kx / 32) * 36`.
///
/// Returns `0` on invalid (non-positive) arguments. Unlike the `_run`
/// and `_can_implement` entry points, the return value is a byte
/// count, not a status code.
pub fn baracuda_kernels_quantize_q8_1_workspace_bytes(
kx: i64, ny: i64,
) -> i64;
/// Multi-M MMVQ for Q8_0 weights, M=1 (decode regime). Computes
/// `dst[0, r] = Σ_c W[r, c] * y[0, c]` for r ∈ [0, nrows_x).
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// The weight buffer, the Q8_1-staged activations, and `dst` must be
/// device-resident, and `stream` must be a valid stream. `ncols_x`
/// must be a multiple of 32.
pub fn baracuda_kernels_mmvq_multim_q8_0_m1_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q8_0_m1_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q8_0_m1_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q8_0 weights, M=2.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as the M=1 entry point.
pub fn baracuda_kernels_mmvq_multim_q8_0_m2_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q8_0_m2_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q8_0_m2_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q8_0 weights, M=4.
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as the M=1 entry point.
pub fn baracuda_kernels_mmvq_multim_q8_0_m4_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q8_0_m4_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q8_0_m4_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q8_0 weights, M=8 (prefill regime, target 3-7×
/// vs the per-token M=1 dispatch).
///
/// Returns a crate-level status code (`0` on success).
///
/// # Safety
///
/// Same contract as the M=1 entry point.
pub fn baracuda_kernels_mmvq_multim_q8_0_m8_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q8_0_m8_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q8_0_m8_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
// ===== Phase 34 — multi-M MMVQ block-format fanout =========================
//
// 9 GGUF block formats × 4 M-sizes (1, 2, 4, 8) = 36 new kernels, each
// exposed as a `_run` / `_can_implement` FFI symbol pair.
// All share the same ABI as the Q8_0 family above; `ncols_x` must be a
// multiple of the block size (32 for Q4_0/Q4_1/Q5_0/Q5_1; 256 for the
// k-quants Q2_K/Q3_K/Q4_K/Q5_K/Q6_K).
//
// Q4_0
/// Multi-M MMVQ for Q4_0 weights, M=1.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 32.
pub fn baracuda_kernels_mmvq_multim_q4_0_m1_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q4_0_m1_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q4_0_m1_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q4_0 weights, M=2.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 32.
pub fn baracuda_kernels_mmvq_multim_q4_0_m2_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q4_0_m2_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q4_0_m2_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q4_0 weights, M=4.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 32.
pub fn baracuda_kernels_mmvq_multim_q4_0_m4_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q4_0_m4_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q4_0_m4_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q4_0 weights, M=8.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 32.
pub fn baracuda_kernels_mmvq_multim_q4_0_m8_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q4_0_m8_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q4_0_m8_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
// Q4_1
/// Multi-M MMVQ for Q4_1 weights, M=1.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 32.
pub fn baracuda_kernels_mmvq_multim_q4_1_m1_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q4_1_m1_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q4_1_m1_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q4_1 weights, M=2.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 32.
pub fn baracuda_kernels_mmvq_multim_q4_1_m2_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q4_1_m2_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q4_1_m2_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q4_1 weights, M=4.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 32.
pub fn baracuda_kernels_mmvq_multim_q4_1_m4_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q4_1_m4_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q4_1_m4_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q4_1 weights, M=8.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 32.
pub fn baracuda_kernels_mmvq_multim_q4_1_m8_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q4_1_m8_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q4_1_m8_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
// Q5_0
/// Multi-M MMVQ for Q5_0 weights, M=1.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 32.
pub fn baracuda_kernels_mmvq_multim_q5_0_m1_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q5_0_m1_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q5_0_m1_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q5_0 weights, M=2.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 32.
pub fn baracuda_kernels_mmvq_multim_q5_0_m2_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q5_0_m2_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q5_0_m2_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q5_0 weights, M=4.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 32.
pub fn baracuda_kernels_mmvq_multim_q5_0_m4_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q5_0_m4_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q5_0_m4_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q5_0 weights, M=8.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 32.
pub fn baracuda_kernels_mmvq_multim_q5_0_m8_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q5_0_m8_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q5_0_m8_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
// Q5_1
/// Multi-M MMVQ for Q5_1 weights, M=1.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 32.
pub fn baracuda_kernels_mmvq_multim_q5_1_m1_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q5_1_m1_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q5_1_m1_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q5_1 weights, M=2.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 32.
pub fn baracuda_kernels_mmvq_multim_q5_1_m2_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q5_1_m2_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q5_1_m2_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q5_1 weights, M=4.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 32.
pub fn baracuda_kernels_mmvq_multim_q5_1_m4_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q5_1_m4_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q5_1_m4_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q5_1 weights, M=8.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 32.
pub fn baracuda_kernels_mmvq_multim_q5_1_m8_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q5_1_m8_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q5_1_m8_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
// Q2_K
/// Multi-M MMVQ for Q2_K weights, M=1.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 256.
pub fn baracuda_kernels_mmvq_multim_q2_K_m1_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q2_K_m1_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q2_K_m1_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q2_K weights, M=2.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 256.
pub fn baracuda_kernels_mmvq_multim_q2_K_m2_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q2_K_m2_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q2_K_m2_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q2_K weights, M=4.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 256.
pub fn baracuda_kernels_mmvq_multim_q2_K_m4_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q2_K_m4_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q2_K_m4_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q2_K weights, M=8.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 256.
pub fn baracuda_kernels_mmvq_multim_q2_K_m8_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q2_K_m8_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q2_K_m8_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
// Q3_K
/// Multi-M MMVQ for Q3_K weights, M=1.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 256.
pub fn baracuda_kernels_mmvq_multim_q3_K_m1_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q3_K_m1_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q3_K_m1_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q3_K weights, M=2.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 256.
pub fn baracuda_kernels_mmvq_multim_q3_K_m2_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_mmvq_multim_q3_K_m2_run`.
///
/// Takes the same problem arguments (no workspace or stream; `dst` is
/// `*const`) and returns a crate-level status code (`0` on success).
/// NOTE(review): presumably a validation-only query — confirm.
pub fn baracuda_kernels_mmvq_multim_q3_K_m2_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// Multi-M MMVQ for Q3_K weights, M=4.
///
/// Same ABI as the Q8_0 multi-M family; returns a crate-level status
/// code (`0` on success).
///
/// # Safety
///
/// Same contract as the Q8_0 multi-M entry points; `ncols_x` must be a
/// multiple of 256.
pub fn baracuda_kernels_mmvq_multim_q3_K_m4_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point of the `mmvq_multim` kernel — Q3_K format,
/// `m4` variant.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_multim_q3_K_m4_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// `run` entry point of the `mmvq_multim` kernel — Q3_K format, `m8` variant.
///
/// Returns an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): `m8` presumably fixes the activation-row count — confirm.
///
/// # Safety
///
/// Pointers must be valid device addresses, `workspace` (when non-null)
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream, as laid out in the crate-level docs.
pub fn baracuda_kernels_mmvq_multim_q3_K_m8_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point of the `mmvq_multim` kernel — Q3_K format,
/// `m8` variant.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_multim_q3_K_m8_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
// Q4_K
/// `run` entry point of the `mmvq_multim` kernel — Q4_K format, `m1` variant.
///
/// Returns an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): `m1` presumably fixes the activation-row count — confirm.
///
/// # Safety
///
/// Pointers must be valid device addresses, `workspace` (when non-null)
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream, as laid out in the crate-level docs.
pub fn baracuda_kernels_mmvq_multim_q4_K_m1_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point of the `mmvq_multim` kernel — Q4_K format,
/// `m1` variant.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_multim_q4_K_m1_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// `run` entry point of the `mmvq_multim` kernel — Q4_K format, `m2` variant.
///
/// Returns an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): `m2` presumably fixes the activation-row count — confirm.
///
/// # Safety
///
/// Pointers must be valid device addresses, `workspace` (when non-null)
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream, as laid out in the crate-level docs.
pub fn baracuda_kernels_mmvq_multim_q4_K_m2_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point of the `mmvq_multim` kernel — Q4_K format,
/// `m2` variant.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_multim_q4_K_m2_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// `run` entry point of the `mmvq_multim` kernel — Q4_K format, `m4` variant.
///
/// Returns an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): `m4` presumably fixes the activation-row count — confirm.
///
/// # Safety
///
/// Pointers must be valid device addresses, `workspace` (when non-null)
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream, as laid out in the crate-level docs.
pub fn baracuda_kernels_mmvq_multim_q4_K_m4_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point of the `mmvq_multim` kernel — Q4_K format,
/// `m4` variant.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_multim_q4_K_m4_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// `run` entry point of the `mmvq_multim` kernel — Q4_K format, `m8` variant.
///
/// Returns an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): `m8` presumably fixes the activation-row count — confirm.
///
/// # Safety
///
/// Pointers must be valid device addresses, `workspace` (when non-null)
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream, as laid out in the crate-level docs.
pub fn baracuda_kernels_mmvq_multim_q4_K_m8_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point of the `mmvq_multim` kernel — Q4_K format,
/// `m8` variant.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_multim_q4_K_m8_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
// Q5_K
/// `run` entry point of the `mmvq_multim` kernel — Q5_K format, `m1` variant.
///
/// Returns an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): `m1` presumably fixes the activation-row count — confirm.
///
/// # Safety
///
/// Pointers must be valid device addresses, `workspace` (when non-null)
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream, as laid out in the crate-level docs.
pub fn baracuda_kernels_mmvq_multim_q5_K_m1_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point of the `mmvq_multim` kernel — Q5_K format,
/// `m1` variant.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_multim_q5_K_m1_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// `run` entry point of the `mmvq_multim` kernel — Q5_K format, `m2` variant.
///
/// Returns an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): `m2` presumably fixes the activation-row count — confirm.
///
/// # Safety
///
/// Pointers must be valid device addresses, `workspace` (when non-null)
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream, as laid out in the crate-level docs.
pub fn baracuda_kernels_mmvq_multim_q5_K_m2_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point of the `mmvq_multim` kernel — Q5_K format,
/// `m2` variant.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_multim_q5_K_m2_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// `run` entry point of the `mmvq_multim` kernel — Q5_K format, `m4` variant.
///
/// Returns an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): `m4` presumably fixes the activation-row count — confirm.
///
/// # Safety
///
/// Pointers must be valid device addresses, `workspace` (when non-null)
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream, as laid out in the crate-level docs.
pub fn baracuda_kernels_mmvq_multim_q5_K_m4_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point of the `mmvq_multim` kernel — Q5_K format,
/// `m4` variant.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_multim_q5_K_m4_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// `run` entry point of the `mmvq_multim` kernel — Q5_K format, `m8` variant.
///
/// Returns an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): `m8` presumably fixes the activation-row count — confirm.
///
/// # Safety
///
/// Pointers must be valid device addresses, `workspace` (when non-null)
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream, as laid out in the crate-level docs.
pub fn baracuda_kernels_mmvq_multim_q5_K_m8_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point of the `mmvq_multim` kernel — Q5_K format,
/// `m8` variant.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_multim_q5_K_m8_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
// Q6_K
/// `run` entry point of the `mmvq_multim` kernel — Q6_K format, `m1` variant.
///
/// Returns an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): `m1` presumably fixes the activation-row count — confirm.
///
/// # Safety
///
/// Pointers must be valid device addresses, `workspace` (when non-null)
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream, as laid out in the crate-level docs.
pub fn baracuda_kernels_mmvq_multim_q6_K_m1_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point of the `mmvq_multim` kernel — Q6_K format,
/// `m1` variant.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_multim_q6_K_m1_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// `run` entry point of the `mmvq_multim` kernel — Q6_K format, `m2` variant.
///
/// Returns an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): `m2` presumably fixes the activation-row count — confirm.
///
/// # Safety
///
/// Pointers must be valid device addresses, `workspace` (when non-null)
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream, as laid out in the crate-level docs.
pub fn baracuda_kernels_mmvq_multim_q6_K_m2_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point of the `mmvq_multim` kernel — Q6_K format,
/// `m2` variant.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_multim_q6_K_m2_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// `run` entry point of the `mmvq_multim` kernel — Q6_K format, `m4` variant.
///
/// Returns an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): `m4` presumably fixes the activation-row count — confirm.
///
/// # Safety
///
/// Pointers must be valid device addresses, `workspace` (when non-null)
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream, as laid out in the crate-level docs.
pub fn baracuda_kernels_mmvq_multim_q6_K_m4_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point of the `mmvq_multim` kernel — Q6_K format,
/// `m4` variant.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_multim_q6_K_m4_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
/// `run` entry point of the `mmvq_multim` kernel — Q6_K format, `m8` variant.
///
/// Returns an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): `m8` presumably fixes the activation-row count — confirm.
///
/// # Safety
///
/// Pointers must be valid device addresses, `workspace` (when non-null)
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream, as laid out in the crate-level docs.
pub fn baracuda_kernels_mmvq_multim_q6_K_m8_run(
ncols_x: i32, nrows_x: i32,
w_ptr: *const c_void, w_start_byte_offset: i64,
activations_q8_1: *const c_void, dst: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point of the `mmvq_multim` kernel — Q6_K format,
/// `m8` variant.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_multim_q6_K_m8_can_implement(
ncols_x: i32,
nrows_x: i32,
w_ptr: *const c_void,
w_start_byte_offset: i64,
activations_q8_1: *const c_void,
dst: *const c_void,
) -> i32;
// ===== Phase 20.1 — batched MMVQ × N-experts ==============================
//
// 11 block formats × 3 activation dtypes (f32 / f16 / bf16) = 33 quant
// symbols, plus 3 pure-FP symbols = 36 new FFI entry points.
//
// Op shape:
// For each dispatch i ∈ [0, M_total):
// token = sorted_token_ids[i]
// expert = find_e(expert_offsets, i) // precomputed once via prelude
// w = topk_weights[i] (or 1.0 if topk_weights == nullptr)
// for r ∈ [0, n_rows_per_expert):
// output[token, r] (+)= w * dot(weights[expert, r, :], activations[token, :])
//
// (+)= is a regular store when top_k == 1 (no output aliasing) and an
// atomicAdd (via `baracuda::atomic::add<T>`) when top_k > 1. The caller
// MUST zero-initialize `output` before the call when top_k > 1.
//
// Convention:
// * `_batched_run` — f32 activation + output (un-suffixed = canonical).
// * `_batched_f16_run` — f16 activation + output.
// * `_batched_bf16_run` — bf16 activation + output.
// * `mmvq_batched_<dt>_run` — pure-FP (no quant) variants.
//
// Workspace: `m_total * sizeof(i32)` bytes. The launcher derives
// `m_total = workspace_bytes / sizeof(i32)` (caller's responsibility to
// size accurately). Buffer is used for the `dispatch_to_expert[]`
// prelude (avoids per-block binary search over `expert_offsets[]`).
//
// Status codes: 0 success, 2 invalid problem (nullptrs / non-positive
// dims), 4 workspace too small, 5 internal launch failure.
// ---- Quant: type-0/1 (f32 = un-suffixed canonical) -------------------
/// Batched MMVQ over experts — Q4_0 block format, f32 activations and
/// output (the un-suffixed, canonical variant).
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, and `workspace` must hold at least `m_total * 4` bytes.
pub fn baracuda_kernels_mmvq_q4_0_batched_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q4_0, f32.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q4_0_batched_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q4_0 block format, f16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, and `workspace` must hold at least `m_total * 4` bytes.
pub fn baracuda_kernels_mmvq_q4_0_batched_f16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q4_0, f16.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q4_0_batched_f16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q4_0 block format, bf16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, and `workspace` must hold at least `m_total * 4` bytes.
pub fn baracuda_kernels_mmvq_q4_0_batched_bf16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q4_0, bf16.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q4_0_batched_bf16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q4_1 block format, f32 activations and
/// output (the un-suffixed, canonical variant).
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, and `workspace` must hold at least `m_total * 4` bytes.
pub fn baracuda_kernels_mmvq_q4_1_batched_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q4_1, f32.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q4_1_batched_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q4_1 block format, f16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, and `workspace` must hold at least `m_total * 4` bytes.
pub fn baracuda_kernels_mmvq_q4_1_batched_f16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q4_1, f16.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q4_1_batched_f16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q4_1 block format, bf16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, and `workspace` must hold at least `m_total * 4` bytes.
pub fn baracuda_kernels_mmvq_q4_1_batched_bf16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q4_1, bf16.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q4_1_batched_bf16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q5_0 block format, f32 activations and
/// output (the un-suffixed, canonical variant).
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, and `workspace` must hold at least `m_total * 4` bytes.
pub fn baracuda_kernels_mmvq_q5_0_batched_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q5_0, f32.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q5_0_batched_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q5_0 block format, f16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, and `workspace` must hold at least `m_total * 4` bytes.
pub fn baracuda_kernels_mmvq_q5_0_batched_f16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q5_0, f16.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q5_0_batched_f16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q5_0 block format, bf16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, and `workspace` must hold at least `m_total * 4` bytes.
pub fn baracuda_kernels_mmvq_q5_0_batched_bf16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q5_0, bf16.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q5_0_batched_bf16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q5_1 block format, f32 activations and
/// output (the un-suffixed, canonical variant).
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, and `workspace` must hold at least `m_total * 4` bytes.
pub fn baracuda_kernels_mmvq_q5_1_batched_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q5_1, f32.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q5_1_batched_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q5_1 block format, f16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, and `workspace` must hold at least `m_total * 4` bytes.
pub fn baracuda_kernels_mmvq_q5_1_batched_f16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q5_1, f16.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q5_1_batched_f16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q5_1 block format, bf16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, and `workspace` must hold at least `m_total * 4` bytes.
pub fn baracuda_kernels_mmvq_q5_1_batched_bf16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q5_1, bf16.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q5_1_batched_bf16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q8_0 block format, f32 activations and
/// output (the un-suffixed, canonical variant).
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, and `workspace` must hold at least `m_total * 4` bytes.
pub fn baracuda_kernels_mmvq_q8_0_batched_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q8_0, f32.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q8_0_batched_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q8_0 block format, f16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, and `workspace` must hold at least `m_total * 4` bytes.
pub fn baracuda_kernels_mmvq_q8_0_batched_f16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q8_0, f16.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q8_0_batched_f16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q8_0 block format, bf16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, and `workspace` must hold at least `m_total * 4` bytes.
pub fn baracuda_kernels_mmvq_q8_0_batched_bf16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q8_0, bf16.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q8_0_batched_bf16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
// ---- Quant: k-quants (256-elt blocks) --------------------------------
/// Batched MMVQ over experts — Q2_K block format, f32 activations and
/// output (the un-suffixed, canonical variant).
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, `workspace` must hold at least `m_total * 4` bytes, and
/// `n_cols` must be a multiple of 256.
pub fn baracuda_kernels_mmvq_q2_K_batched_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q2_K, f32.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q2_K_batched_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q2_K block format, f16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, `workspace` must hold at least `m_total * 4` bytes, and
/// `n_cols` must be a multiple of 256.
pub fn baracuda_kernels_mmvq_q2_K_batched_f16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q2_K, f16.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q2_K_batched_f16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q2_K block format, bf16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, `workspace` must hold at least `m_total * 4` bytes, and
/// `n_cols` must be a multiple of 256.
pub fn baracuda_kernels_mmvq_q2_K_batched_bf16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q2_K, bf16.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q2_K_batched_bf16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q3_K block format, f32 activations and
/// output (the un-suffixed, canonical variant).
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, `workspace` must hold at least `m_total * 4` bytes, and
/// `n_cols` must be a multiple of 256.
pub fn baracuda_kernels_mmvq_q3_K_batched_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q3_K, f32.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q3_K_batched_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q3_K block format, f16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, `workspace` must hold at least `m_total * 4` bytes, and
/// `n_cols` must be a multiple of 256.
pub fn baracuda_kernels_mmvq_q3_K_batched_f16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q3_K, f16.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q3_K_batched_f16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q3_K block format, bf16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, `workspace` must hold at least `m_total * 4` bytes, and
/// `n_cols` must be a multiple of 256.
pub fn baracuda_kernels_mmvq_q3_K_batched_bf16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q3_K, bf16.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q3_K_batched_bf16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q4_K block format, f32 activations and
/// output (the un-suffixed, canonical variant).
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, `workspace` must hold at least `m_total * 4` bytes, and
/// `n_cols` must be a multiple of 256.
pub fn baracuda_kernels_mmvq_q4_K_batched_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q4_K, f32.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q4_K_batched_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q4_K block format, f16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, `workspace` must hold at least `m_total * 4` bytes, and
/// `n_cols` must be a multiple of 256.
pub fn baracuda_kernels_mmvq_q4_K_batched_f16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q4_K, f16.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q4_K_batched_f16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q4_K block format, bf16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, `workspace` must hold at least `m_total * 4` bytes, and
/// `n_cols` must be a multiple of 256.
pub fn baracuda_kernels_mmvq_q4_K_batched_bf16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q4_K, bf16.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q4_K_batched_bf16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q5_K block format, f32 activations and
/// output (the un-suffixed, canonical variant).
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, `workspace` must hold at least `m_total * 4` bytes, and
/// `n_cols` must be a multiple of 256.
pub fn baracuda_kernels_mmvq_q5_K_batched_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q5_K, f32.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q5_K_batched_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q5_K block format, f16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, `workspace` must hold at least `m_total * 4` bytes, and
/// `n_cols` must be a multiple of 256.
pub fn baracuda_kernels_mmvq_q5_K_batched_f16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` entry point for batched MMVQ — Q5_K, f16.
///
/// Takes the problem arguments only (no workspace or stream) and returns
/// an `i32` status per the crate-level table (`0` = success).
/// NOTE(review): presumably validates without launching — confirm.
pub fn baracuda_kernels_mmvq_q5_K_batched_f16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ over experts — Q5_K block format, bf16 activations and output.
///
/// Status: `0` success, `2` invalid problem, `4` workspace too small,
/// `5` internal launch failure. A null `topk_weights` means weight 1.0;
/// the caller must zero `output` beforehand when `top_k > 1`.
///
/// # Safety
///
/// Non-null pointers must be device-resident, `stream` must be a valid
/// stream, `workspace` must hold at least `m_total * 4` bytes, and
/// `n_cols` must be a multiple of 256.
pub fn baracuda_kernels_mmvq_q5_K_batched_bf16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_mmvq_q5_K_batched_bf16_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_mmvq_q5_K_batched_bf16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ — Q6_K, f32.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the Q2_K variant.
pub fn baracuda_kernels_mmvq_q6_K_batched_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_mmvq_q6_K_batched_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_mmvq_q6_K_batched_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ — Q6_K, f16.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the Q2_K variant.
pub fn baracuda_kernels_mmvq_q6_K_batched_f16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_mmvq_q6_K_batched_f16_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_mmvq_q6_K_batched_f16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ — Q6_K, bf16.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the Q2_K variant.
pub fn baracuda_kernels_mmvq_q6_K_batched_bf16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_mmvq_q6_K_batched_bf16_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_mmvq_q6_K_batched_bf16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ — Q8_K (bespoke), f32.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the Q2_K variant.
pub fn baracuda_kernels_mmvq_q8_K_batched_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_mmvq_q8_K_batched_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_mmvq_q8_K_batched_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ — Q8_K (bespoke), f16.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the Q2_K variant.
pub fn baracuda_kernels_mmvq_q8_K_batched_f16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_mmvq_q8_K_batched_f16_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_mmvq_q8_K_batched_f16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMVQ — Q8_K (bespoke), bf16.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the Q2_K variant.
pub fn baracuda_kernels_mmvq_q8_K_batched_bf16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_mmvq_q8_K_batched_bf16_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_mmvq_q8_K_batched_bf16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
// ---- Pure FP (non-quant) batched MMVQ --------------------------------
/// Batched MMV (non-quant) — f32 weights + activation + output.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Device-resident pointers; valid stream; workspace ≥ m_total*4.
/// NOTE(review): `m_total` is not a parameter of this function — confirm
/// its definition (and that the bound is in bytes) against the kernel.
pub fn baracuda_kernels_mmvq_batched_f32_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_mmvq_batched_f32_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_mmvq_batched_f32_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMV (non-quant) — f16.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the f32 variant.
pub fn baracuda_kernels_mmvq_batched_f16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_mmvq_batched_f16_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_mmvq_batched_f16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
/// Batched MMV (non-quant) — bf16.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the f32 variant.
pub fn baracuda_kernels_mmvq_batched_bf16_run(
n_experts: i32, n_rows_per_expert: i32, n_cols: i32,
weights: *const c_void, activations: *const c_void,
sorted_token_ids: *const i32, expert_offsets: *const i32,
topk_weights: *const f32, output: *mut c_void, top_k: i32,
workspace: *mut c_void, workspace_bytes: usize, stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_mmvq_batched_bf16_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_mmvq_batched_bf16_can_implement(
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *const c_void,
top_k: i32,
) -> i32;
}
// ============================================================================
// Mixture-of-Experts forward — Phase 8 Milestone 8.5 (Category V)
// ============================================================================
//
// Vendored from `attention.rs` via fuel-cuda-kernels. See
// `kernels/include/baracuda_moe.cuh` and
// `LICENSE-thirdparty.md` for lineage notes.
//
// ## Phase 20.2 — Fuel-replacement contract (2026-05-25)
//
// These symbols are the canonical baracuda MoE surface. Fuel's
// `fuel-cuda-kernels/src/moe/` is retired in favour of direct calls to
// the symbols below. Recon (2026-05-25) confirmed Fuel's source has
// not changed since the original Phase 8.5 vendor (single commit
// touching those paths), so baracuda's kernel bodies are already
// current — no refresh diff to apply.
//
// Symbol-shape vs Fuel's pre-20.2 `ffi.rs`:
//
// | Fuel (retired) | baracuda (canonical) |
// | ---------------------- | ----------------------------------------------- |
// | `moe_gemm_wmma` | `baracuda_kernels_moe_wmma_{f16,bf16}_run` |
// | `moe_gemm_gguf` | `baracuda_kernels_moe_scalar_gguf_run` |
// | `moe_gemm_gguf_prefill`| `baracuda_kernels_moe_wmma_gguf_{f16,bf16}_run` |
//
// Baracuda collapses activation-dtype into the symbol name (the
// project-wide FFI convention) and adds the
// `(workspace, workspace_bytes)` pair before `stream` (also project-
// wide). Fuel callers pass `(nullptr, 0)` for workspace and otherwise
// have a 1:1 parameter mapping. See
// `crates/baracuda-kernels/tests/moe_ffi_direct_smoke.rs` for the
// reference direct-FFI call pattern.
//
// Three variants:
// * `_moe_scalar_gguf_run` — f32 activations, GGUF-packed expert
// weights. Scalar dispatch (no tensor cores). Stages activations
// through a q8_1 intermediate allocated internally.
// * `_moe_wmma_<f16|bf16>_run` — FP activations + FP weights, sm_70+
// WMMA tensor cores.
// * `_moe_wmma_gguf_<f16|bf16>_run` — FP activations + GGUF-packed
// expert weights, sm_70+ WMMA tensor cores. The production hot
// path for quantized LLM inference.
//
// `gguf_dtype` argument for the GGUF variants follows Fuel's
// discriminant numbering (NOT baracuda's `GgufBlockFormat::repr`):
// 0 = Q8_0, 1 = Q4_K, 2 = Q2_K, 3 = Q3_K, 4 = Q5_K, 5 = Q6_K.
//
// `is_prefill` (WMMA-only variant) selects between prefill geometry
// (M_tile=16, N_tile=16) and decode geometry (M_tile=8, N_tile=32).
//
// The WMMA paths require caller-allocated scratch buffers
// `expert_counts[num_experts]` and `expert_offsets[num_experts + 1]`
// (both i32, device-resident).
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
/// MoE forward — scalar dispatch path on GGUF-packed expert weights.
/// f32 activations in, f32 output out.
///
/// `gguf_dtype` discriminant (Fuel-convention, not
/// `GgufBlockFormat::repr`):
/// `0 = Q8_0`, `1 = Q4_K`, `2 = Q2_K`, `3 = Q3_K`, `4 = Q5_K`, `5 = Q6_K`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// All pointer args must be device-resident with the documented
/// shapes; `stream` must be a valid CUDA stream pointer.
pub fn baracuda_kernels_moe_scalar_gguf_run(
inputs: *const c_void, // f32 [size_m_input, size_k]
weights: *const c_void, // packed GGUF [num_experts, size_n, size_k]
sorted_token_ids: *const i32, // [size_m]
expert_ids: *const i32, // [size_m]
topk_weights: *const f32, // [size_m] or null
outputs: *mut c_void, // f32 [size_m_input, size_n]
num_experts: i32,
topk: i32,
size_m: i32,
size_n: i32,
size_k: i32,
gguf_dtype: i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_moe_scalar_gguf_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `outputs` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_moe_scalar_gguf_can_implement(
inputs: *const c_void,
weights: *const c_void,
sorted_token_ids: *const i32,
expert_ids: *const i32,
topk_weights: *const f32,
outputs: *const c_void,
num_experts: i32,
topk: i32,
size_m: i32,
size_n: i32,
size_k: i32,
gguf_dtype: i32,
) -> i32;
/// MoE forward — WMMA FP weights, f16 activations + weights, f16 output.
/// Output buffer must be zero-initialized by the caller when
/// `topk_weights == null` and `topk > 1` (multiple writes per row).
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// As `moe_scalar_gguf_run`.
pub fn baracuda_kernels_moe_wmma_f16_run(
input: *const c_void,
weights: *const c_void,
sorted_token_ids: *const i32,
expert_ids: *const i32,
topk_weights: *const f32,
output: *mut c_void,
expert_counts: *mut i32, // prealloc [num_experts]
expert_offsets: *mut i32, // prealloc [num_experts + 1]
num_experts: i32,
topk: i32,
size_m: i32,
size_n: i32,
size_k: i32,
is_prefill: i32, // 0 = decode (M=8, N=32); !=0 = prefill (M=16, N=16)
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_moe_wmma_f16_run`].
///
/// Takes only the scalar problem description (`num_experts`, `topk`,
/// `size_m`, `size_n`, `size_k`, `is_prefill`); no pointers are passed.
/// Returns the crate-level `i32` status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_moe_wmma_f16_can_implement(
num_experts: i32, topk: i32, size_m: i32, size_n: i32, size_k: i32, is_prefill: i32,
) -> i32;
/// MoE forward — WMMA FP weights, bf16 activations + weights, bf16 output.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// As `moe_wmma_f16_run`.
pub fn baracuda_kernels_moe_wmma_bf16_run(
input: *const c_void,
weights: *const c_void,
sorted_token_ids: *const i32,
expert_ids: *const i32,
topk_weights: *const f32,
output: *mut c_void,
expert_counts: *mut i32,
expert_offsets: *mut i32,
num_experts: i32,
topk: i32,
size_m: i32,
size_n: i32,
size_k: i32,
is_prefill: i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_moe_wmma_bf16_run`].
///
/// Takes only the scalar problem description (`num_experts`, `topk`,
/// `size_m`, `size_n`, `size_k`, `is_prefill`); no pointers are passed.
/// Returns the crate-level `i32` status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_moe_wmma_bf16_can_implement(
num_experts: i32, topk: i32, size_m: i32, size_n: i32, size_k: i32, is_prefill: i32,
) -> i32;
/// MoE forward — WMMA + GGUF combined path. f16 activations,
/// GGUF-packed weights, f32 output.
///
/// `gguf_dtype` discriminant (Fuel-convention):
/// `0 = Q8_0`, `1 = Q4_K`, `2 = Q2_K`, `3 = Q3_K`, `4 = Q5_K`, `5 = Q6_K`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// As `moe_wmma_f16_run`.
pub fn baracuda_kernels_moe_wmma_gguf_f16_run(
input: *const c_void,
weights: *const c_void, // packed GGUF bytes
sorted_token_ids: *const i32,
expert_ids: *const i32,
topk_weights: *const f32,
output: *mut c_void, // f32
expert_counts: *mut i32,
expert_offsets: *mut i32,
num_experts: i32,
topk: i32,
size_m: i32,
size_n: i32,
size_k: i32,
gguf_dtype: i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_moe_wmma_gguf_f16_run`].
///
/// Takes only the scalar problem description (`num_experts`, `topk`,
/// `size_m`, `size_n`, `size_k`, `gguf_dtype`); no pointers are passed.
/// Returns the crate-level `i32` status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_moe_wmma_gguf_f16_can_implement(
num_experts: i32, topk: i32, size_m: i32, size_n: i32, size_k: i32, gguf_dtype: i32,
) -> i32;
/// MoE forward — WMMA + GGUF combined path, bf16 activations.
///
/// Returns the crate-level `i32` status code (`0` = success).
/// NOTE(review): the output dtype is not stated for this variant (the f16
/// sibling documents f32 output) — confirm against the kernel.
///
/// # Safety
/// As `moe_wmma_gguf_f16_run`.
pub fn baracuda_kernels_moe_wmma_gguf_bf16_run(
input: *const c_void,
weights: *const c_void,
sorted_token_ids: *const i32,
expert_ids: *const i32,
topk_weights: *const f32,
output: *mut c_void,
expert_counts: *mut i32,
expert_offsets: *mut i32,
num_experts: i32,
topk: i32,
size_m: i32,
size_n: i32,
size_k: i32,
gguf_dtype: i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_moe_wmma_gguf_bf16_run`].
///
/// Takes only the scalar problem description (`num_experts`, `topk`,
/// `size_m`, `size_n`, `size_k`, `gguf_dtype`); no pointers are passed.
/// Returns the crate-level `i32` status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_moe_wmma_gguf_bf16_can_implement(
num_experts: i32, topk: i32, size_m: i32, size_n: i32, size_k: i32, gguf_dtype: i32,
) -> i32;
}
// ============================================================================
// Image / spatial transforms — Phase 9 Category T
// ============================================================================
//
// Trailblazer set: interpolate (bilinear 2D), grid_sample (2D),
// affine_grid (2D), pixel_shuffle / pixel_unshuffle, roi_align, roi_pool,
// nms. NCHW layout throughout. f32 + f64 for math-bearing ops;
// interpolate (bilinear 2D, FW + BW) and pixel_shuffle /
// pixel_unshuffle also wire f16 + bf16.
//
// BW ops that scatter via atomicAdd (`interpolate_bilinear_2d_backward`,
// `grid_sample_2d_backward`, `roi_align_backward`, `roi_pool_backward`)
// require caller to pre-zero the output gradient buffers.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---- interpolate (bilinear 2D) ----
//
// Phase 21 signature: the three trailing params before `stream`
// (`align_corners`, `scale_h_factor`, `scale_w_factor`) are new.
// `align_corners=0` matches `F.interpolate` PyTorch default;
// nonzero matches `nn.Upsample(align_corners=True)`. Scale-factor
// sentinel `0.0` means "derive from sizes"; any nonzero value is
// interpreted as PyTorch-style SCALE (output_size / input_size)
// and overrides the size-derived ratio.
//
// Dtype coverage: f32, f64, f16, bf16 (FW + BW each).
/// `interpolate(x, mode='bilinear')` FW, f32.
///
/// `input`: `[N, C, IH, IW]`; `output`: `[N, C, OH, OW]`. NCHW.
/// `align_corners`: 0 = false (PyTorch default), nonzero = true.
/// `scale_h_factor` / `scale_w_factor`: 0.0 = derive from
/// sizes; nonzero = use as SCALE override.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// All pointers must be live device memory; `stream` valid.
pub fn baracuda_kernels_interpolate_bilinear_2d_f32_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
align_corners: i32,
scale_h_factor: f64,
scale_w_factor: f64,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_interpolate_bilinear_2d_f32_run`].
///
/// Takes the `_run` arguments minus `stream`; `output` and `workspace`
/// are `*const` here. Returns the crate-level `i32` status code
/// (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_interpolate_bilinear_2d_f32_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
output: *const c_void,
workspace: *const c_void,
workspace_bytes: usize,
align_corners: i32,
scale_h_factor: f64,
scale_w_factor: f64,
) -> i32;
/// `interpolate_bilinear_2d` FW, f64.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the f32 variant.
pub fn baracuda_kernels_interpolate_bilinear_2d_f64_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
align_corners: i32,
scale_h_factor: f64,
scale_w_factor: f64,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_interpolate_bilinear_2d_f64_run`].
///
/// Takes the `_run` arguments minus `stream`; `output` and `workspace`
/// are `*const` here. Returns the crate-level `i32` status code
/// (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_interpolate_bilinear_2d_f64_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
output: *const c_void,
workspace: *const c_void,
workspace_bytes: usize,
align_corners: i32,
scale_h_factor: f64,
scale_w_factor: f64,
) -> i32;
/// `interpolate_bilinear_2d` FW, f16 (half). Cast-at-read / f32
/// accumulator / cast-at-write.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the f32 variant.
pub fn baracuda_kernels_interpolate_bilinear_2d_f16_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
align_corners: i32,
scale_h_factor: f64,
scale_w_factor: f64,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_interpolate_bilinear_2d_f16_run`].
///
/// Takes the `_run` arguments minus `stream`; `output` and `workspace`
/// are `*const` here. Returns the crate-level `i32` status code
/// (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_interpolate_bilinear_2d_f16_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
output: *const c_void,
workspace: *const c_void,
workspace_bytes: usize,
align_corners: i32,
scale_h_factor: f64,
scale_w_factor: f64,
) -> i32;
/// `interpolate_bilinear_2d` FW, bf16. Cast-at-read / f32
/// accumulator / cast-at-write.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the f32 variant.
pub fn baracuda_kernels_interpolate_bilinear_2d_bf16_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
align_corners: i32,
scale_h_factor: f64,
scale_w_factor: f64,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_interpolate_bilinear_2d_bf16_run`].
///
/// Takes the `_run` arguments minus `stream`; `output` and `workspace`
/// are `*const` here. Returns the crate-level `i32` status code
/// (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_interpolate_bilinear_2d_bf16_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
output: *const c_void,
workspace: *const c_void,
workspace_bytes: usize,
align_corners: i32,
scale_h_factor: f64,
scale_w_factor: f64,
) -> i32;
/// `interpolate_bilinear_2d` BW, f32. Caller pre-zeros `dinput`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the FW entry points.
pub fn baracuda_kernels_interpolate_bilinear_2d_backward_f32_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
align_corners: i32,
scale_h_factor: f64,
scale_w_factor: f64,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// [`baracuda_kernels_interpolate_bilinear_2d_backward_f32_run`].
///
/// Takes the `_run` arguments minus `stream`; `dinput` and `workspace`
/// are `*const` here. Returns the crate-level `i32` status code
/// (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_interpolate_bilinear_2d_backward_f32_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
dinput: *const c_void,
workspace: *const c_void,
workspace_bytes: usize,
align_corners: i32,
scale_h_factor: f64,
scale_w_factor: f64,
) -> i32;
/// `interpolate_bilinear_2d` BW, f64. Caller pre-zeros `dinput` (the
/// section header requires this for every
/// `interpolate_bilinear_2d_backward` variant).
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the f32 BW variant.
pub fn baracuda_kernels_interpolate_bilinear_2d_backward_f64_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
align_corners: i32,
scale_h_factor: f64,
scale_w_factor: f64,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// [`baracuda_kernels_interpolate_bilinear_2d_backward_f64_run`].
///
/// Takes the `_run` arguments minus `stream`; `dinput` and `workspace`
/// are `*const` here. Returns the crate-level `i32` status code
/// (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_interpolate_bilinear_2d_backward_f64_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
dinput: *const c_void,
workspace: *const c_void,
workspace_bytes: usize,
align_corners: i32,
scale_h_factor: f64,
scale_w_factor: f64,
) -> i32;
/// `interpolate_bilinear_2d` BW, f16. Caller pre-zeros `dinput`.
/// `atomicCAS`-based half atomic add.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the f32 BW variant.
pub fn baracuda_kernels_interpolate_bilinear_2d_backward_f16_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
align_corners: i32,
scale_h_factor: f64,
scale_w_factor: f64,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// [`baracuda_kernels_interpolate_bilinear_2d_backward_f16_run`].
///
/// Takes the `_run` arguments minus `stream`; `dinput` and `workspace`
/// are `*const` here. Returns the crate-level `i32` status code
/// (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_interpolate_bilinear_2d_backward_f16_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
dinput: *const c_void,
workspace: *const c_void,
workspace_bytes: usize,
align_corners: i32,
scale_h_factor: f64,
scale_w_factor: f64,
) -> i32;
/// `interpolate_bilinear_2d` BW, bf16. Caller pre-zeros `dinput`.
/// `atomicCAS`-based bf16 atomic add.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the f32 BW variant.
pub fn baracuda_kernels_interpolate_bilinear_2d_backward_bf16_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
align_corners: i32,
scale_h_factor: f64,
scale_w_factor: f64,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// [`baracuda_kernels_interpolate_bilinear_2d_backward_bf16_run`].
///
/// Takes the `_run` arguments minus `stream`; `dinput` and `workspace`
/// are `*const` here. Returns the crate-level `i32` status code
/// (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_interpolate_bilinear_2d_backward_bf16_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
dinput: *const c_void,
workspace: *const c_void,
workspace_bytes: usize,
align_corners: i32,
scale_h_factor: f64,
scale_w_factor: f64,
) -> i32;
// ---- grid_sample (2D, bilinear, zeros pad, align_corners=false) ----
/// `grid_sample(input, grid)` FW, f32. `grid`: `[N, OH, OW, 2]` with
/// (x, y) normalized in [-1, 1].
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the `interpolate_*` entry points.
pub fn baracuda_kernels_grid_sample_2d_f32_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
grid: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_grid_sample_2d_f32_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_grid_sample_2d_f32_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
grid: *const c_void,
output: *const c_void,
) -> i32;
/// `grid_sample_2d` FW, f64.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the f32 variant.
pub fn baracuda_kernels_grid_sample_2d_f64_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
grid: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_grid_sample_2d_f64_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_grid_sample_2d_f64_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
grid: *const c_void,
output: *const c_void,
) -> i32;
/// `grid_sample_2d` BW, f32. Caller pre-zeros `dinput` and `dgrid`.
/// `dgrid`: `[N, OH, OW, 2]`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the FW entry points.
pub fn baracuda_kernels_grid_sample_2d_backward_f32_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
input: *const c_void,
grid: *const c_void,
dinput: *mut c_void,
dgrid: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_grid_sample_2d_backward_f32_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `dinput` and `dgrid` are `*const` here. Returns the
/// crate-level `i32` status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_grid_sample_2d_backward_f32_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
input: *const c_void,
grid: *const c_void,
dinput: *const c_void,
dgrid: *const c_void,
) -> i32;
/// `grid_sample_2d` BW, f64. Caller pre-zeros the output gradient
/// buffers (`dinput`, `dgrid`), as the section header requires for
/// `grid_sample_2d_backward`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the f32 BW variant.
pub fn baracuda_kernels_grid_sample_2d_backward_f64_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
input: *const c_void,
grid: *const c_void,
dinput: *mut c_void,
dgrid: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_grid_sample_2d_backward_f64_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `dinput` and `dgrid` are `*const` here. Returns the
/// crate-level `i32` status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_grid_sample_2d_backward_f64_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
input: *const c_void,
grid: *const c_void,
dinput: *const c_void,
dgrid: *const c_void,
) -> i32;
// ---- affine_grid (2D) ----
/// `affine_grid(theta, size)` — produce `[N, OH, OW, 2]` grid from
/// `theta: [N, 2, 3]`. f32.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the entry points above.
pub fn baracuda_kernels_affine_grid_2d_f32_run(
N: i32, OH: i32, OW: i32,
theta: *const c_void,
grid: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_affine_grid_2d_f32_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `grid` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_affine_grid_2d_f32_can_implement(
N: i32, OH: i32, OW: i32,
theta: *const c_void,
grid: *const c_void,
) -> i32;
/// `affine_grid_2d`, f64.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the f32 variant.
pub fn baracuda_kernels_affine_grid_2d_f64_run(
N: i32, OH: i32, OW: i32,
theta: *const c_void,
grid: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_affine_grid_2d_f64_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `grid` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_affine_grid_2d_f64_can_implement(
N: i32, OH: i32, OW: i32,
theta: *const c_void,
grid: *const c_void,
) -> i32;
// ---- pixel_shuffle / pixel_unshuffle ----
/// `pixel_shuffle(x, r)` — `[N, C·r², H, W] → [N, C, H·r, W·r]`.
/// f32.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the entry points above.
pub fn baracuda_kernels_pixel_shuffle_f32_run(
N: i32, C: i32, H: i32, W: i32, r: i32,
input: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_pixel_shuffle_f32_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_pixel_shuffle_f32_can_implement(
N: i32, C: i32, H: i32, W: i32, r: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// `pixel_shuffle`, f64.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the f32 variant.
pub fn baracuda_kernels_pixel_shuffle_f64_run(
N: i32, C: i32, H: i32, W: i32, r: i32,
input: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_pixel_shuffle_f64_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_pixel_shuffle_f64_can_implement(
N: i32, C: i32, H: i32, W: i32, r: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// `pixel_shuffle`, f16.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the f32 variant.
pub fn baracuda_kernels_pixel_shuffle_f16_run(
N: i32, C: i32, H: i32, W: i32, r: i32,
input: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_pixel_shuffle_f16_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_pixel_shuffle_f16_can_implement(
N: i32, C: i32, H: i32, W: i32, r: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// `pixel_shuffle`, bf16.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the f32 variant.
pub fn baracuda_kernels_pixel_shuffle_bf16_run(
N: i32, C: i32, H: i32, W: i32, r: i32,
input: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_pixel_shuffle_bf16_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_pixel_shuffle_bf16_can_implement(
N: i32, C: i32, H: i32, W: i32, r: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// `pixel_unshuffle(x, r)` — `[N, C, H·r, W·r] → [N, C·r², H, W]`.
/// Inverse of pixel_shuffle (and each is the other's BW). f32.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the entry points above.
pub fn baracuda_kernels_pixel_unshuffle_f32_run(
N: i32, C: i32, H: i32, W: i32, r: i32,
input: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_pixel_unshuffle_f32_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_pixel_unshuffle_f32_can_implement(
N: i32, C: i32, H: i32, W: i32, r: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// `pixel_unshuffle`, f64.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the f32 variant.
pub fn baracuda_kernels_pixel_unshuffle_f64_run(
N: i32, C: i32, H: i32, W: i32, r: i32,
input: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_pixel_unshuffle_f64_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_pixel_unshuffle_f64_can_implement(
N: i32, C: i32, H: i32, W: i32, r: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// `pixel_unshuffle`, f16.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the f32 variant.
pub fn baracuda_kernels_pixel_unshuffle_f16_run(
N: i32, C: i32, H: i32, W: i32, r: i32,
input: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_pixel_unshuffle_f16_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_pixel_unshuffle_f16_can_implement(
N: i32, C: i32, H: i32, W: i32, r: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// `pixel_unshuffle`, bf16.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the f32 variant.
pub fn baracuda_kernels_pixel_unshuffle_bf16_run(
N: i32, C: i32, H: i32, W: i32, r: i32,
input: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of [`baracuda_kernels_pixel_unshuffle_bf16_run`].
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is `*const` here. Returns the crate-level `i32`
/// status code (`0` = success).
/// NOTE(review): presumably validates only and launches nothing — confirm on the C side.
pub fn baracuda_kernels_pixel_unshuffle_bf16_can_implement(
N: i32, C: i32, H: i32, W: i32, r: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
// ---- roi_align / roi_pool ----
/// `roi_align`, f32.
///
/// `rois`: `[num_rois, 5]` (batch_idx, x1, y1, x2, y2)
/// in INPUT-pixel coords (scaled by `spatial_scale` inside the kernel).
/// `sampling_ratio == 0` selects adaptive sampling.
/// `aligned == 0` is PyTorch's pre-0.6 convention.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the entry points above.
pub fn baracuda_kernels_roi_align_f32_run(
N: i32, C: i32, H: i32, W: i32,
num_rois: i32, pooled_h: i32, pooled_w: i32,
spatial_scale: f32, sampling_ratio: i32, aligned: i32,
input: *const c_void,
rois: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `baracuda_kernels_roi_align_f32_can_implement` (baracuda kernels roi align f32 can implement).
pub fn baracuda_kernels_roi_align_f32_can_implement(
N: i32, C: i32, H: i32, W: i32,
num_rois: i32, pooled_h: i32, pooled_w: i32,
spatial_scale: f32, sampling_ratio: i32, aligned: i32,
input: *const c_void,
rois: *const c_void,
output: *const c_void,
) -> i32;
/// `roi_align` forward, f64.
///
/// Same parameter list as the f32 variant. Note that `spatial_scale` is
/// declared as `f32` here as well, even though the element type is f64.
/// NOTE(review): confirm the C prototype also takes `float` for
/// `spatial_scale`.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_roi_align_f32_run`.
pub fn baracuda_kernels_roi_align_f64_run(
N: i32, C: i32, H: i32, W: i32,
num_rois: i32, pooled_h: i32, pooled_w: i32,
spatial_scale: f32, sampling_ratio: i32, aligned: i32,
input: *const c_void,
rois: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of `baracuda_kernels_roi_align_f64_run`.
///
/// Takes the same problem descriptor and operand pointers, minus the
/// workspace and stream, and returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_roi_align_f64_can_implement(
N: i32, C: i32, H: i32, W: i32,
num_rois: i32, pooled_h: i32, pooled_w: i32,
spatial_scale: f32, sampling_ratio: i32, aligned: i32,
input: *const c_void,
rois: *const c_void,
output: *const c_void,
) -> i32;
/// `roi_align` backward, f32.
///
/// Consumes the output gradient `dout` and the `rois`, and writes the
/// input gradient `dinput`. The caller must zero `dinput` before the
/// call.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the forward entry point
/// (`baracuda_kernels_roi_align_f32_run`).
pub fn baracuda_kernels_roi_align_backward_f32_run(
N: i32, C: i32, H: i32, W: i32,
num_rois: i32, pooled_h: i32, pooled_w: i32,
spatial_scale: f32, sampling_ratio: i32, aligned: i32,
dout: *const c_void,
rois: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_roi_align_backward_f32_run`.
///
/// Takes the same problem descriptor and operand pointers, minus the
/// workspace and stream, and returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_roi_align_backward_f32_can_implement(
N: i32, C: i32, H: i32, W: i32,
num_rois: i32, pooled_h: i32, pooled_w: i32,
spatial_scale: f32, sampling_ratio: i32, aligned: i32,
dout: *const c_void,
rois: *const c_void,
dinput: *const c_void,
) -> i32;
/// `roi_align` backward, f64.
///
/// Same parameter list as the f32 backward variant; `spatial_scale` is
/// declared as `f32` here too.
/// NOTE(review): the f32 backward variant requires the caller to
/// pre-zero `dinput`; presumably that holds here as well — confirm.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_roi_align_backward_f32_run`.
pub fn baracuda_kernels_roi_align_backward_f64_run(
N: i32, C: i32, H: i32, W: i32,
num_rois: i32, pooled_h: i32, pooled_w: i32,
spatial_scale: f32, sampling_ratio: i32, aligned: i32,
dout: *const c_void,
rois: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_roi_align_backward_f64_run`.
///
/// Takes the same problem descriptor and operand pointers, minus the
/// workspace and stream, and returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_roi_align_backward_f64_can_implement(
N: i32, C: i32, H: i32, W: i32,
num_rois: i32, pooled_h: i32, pooled_w: i32,
spatial_scale: f32, sampling_ratio: i32, aligned: i32,
dout: *const c_void,
rois: *const c_void,
dinput: *const c_void,
) -> i32;
/// `roi_pool` forward, f32.
///
/// Writes both `output` and `argmax`. `argmax` holds one i32 per output
/// cell: the linear, plane-relative index of the selected input element,
/// or `-1` for an empty bin.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as `roi_align`
/// (`baracuda_kernels_roi_align_f32_run`).
pub fn baracuda_kernels_roi_pool_f32_run(
N: i32, C: i32, H: i32, W: i32,
num_rois: i32, pooled_h: i32, pooled_w: i32,
spatial_scale: f32,
input: *const c_void,
rois: *const c_void,
output: *mut c_void,
argmax: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of `baracuda_kernels_roi_pool_f32_run`.
///
/// Takes the same problem descriptor and operand pointers, minus the
/// workspace and stream, and returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_roi_pool_f32_can_implement(
N: i32, C: i32, H: i32, W: i32,
num_rois: i32, pooled_h: i32, pooled_w: i32,
spatial_scale: f32,
input: *const c_void,
rois: *const c_void,
output: *const c_void,
argmax: *const c_void,
) -> i32;
/// `roi_pool` forward, f64.
///
/// Same parameter list as the f32 variant, including the `argmax`
/// output. `spatial_scale` is declared as `f32` here too.
/// NOTE(review): confirm the C prototype also takes `float` for
/// `spatial_scale`.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_roi_pool_f32_run`.
pub fn baracuda_kernels_roi_pool_f64_run(
N: i32, C: i32, H: i32, W: i32,
num_rois: i32, pooled_h: i32, pooled_w: i32,
spatial_scale: f32,
input: *const c_void,
rois: *const c_void,
output: *mut c_void,
argmax: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of `baracuda_kernels_roi_pool_f64_run`.
///
/// Takes the same problem descriptor and operand pointers, minus the
/// workspace and stream, and returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_roi_pool_f64_can_implement(
N: i32, C: i32, H: i32, W: i32,
num_rois: i32, pooled_h: i32, pooled_w: i32,
spatial_scale: f32,
input: *const c_void,
rois: *const c_void,
output: *const c_void,
argmax: *const c_void,
) -> i32;
/// `roi_pool` backward, f32.
///
/// Consumes `dout`, `rois` and the `argmax` buffer (read-only here), and
/// writes `dinput`. The caller must zero `dinput` before the call.
/// Unlike the forward entry point, this signature has no
/// `spatial_scale` parameter.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as the forward entry point
/// (`baracuda_kernels_roi_pool_f32_run`).
pub fn baracuda_kernels_roi_pool_backward_f32_run(
N: i32, C: i32, H: i32, W: i32,
num_rois: i32, pooled_h: i32, pooled_w: i32,
dout: *const c_void,
rois: *const c_void,
argmax: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_roi_pool_backward_f32_run`.
///
/// Takes the same problem descriptor and operand pointers, minus the
/// workspace and stream, and returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_roi_pool_backward_f32_can_implement(
N: i32, C: i32, H: i32, W: i32,
num_rois: i32, pooled_h: i32, pooled_w: i32,
dout: *const c_void,
rois: *const c_void,
argmax: *const c_void,
dinput: *const c_void,
) -> i32;
/// `roi_pool` backward, f64.
///
/// Same parameter list as the f32 backward variant.
/// NOTE(review): the f32 backward variant requires the caller to
/// pre-zero `dinput`; presumably that holds here as well — confirm.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_roi_pool_backward_f32_run`.
pub fn baracuda_kernels_roi_pool_backward_f64_run(
N: i32, C: i32, H: i32, W: i32,
num_rois: i32, pooled_h: i32, pooled_w: i32,
dout: *const c_void,
rois: *const c_void,
argmax: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_roi_pool_backward_f64_run`.
///
/// Takes the same problem descriptor and operand pointers, minus the
/// workspace and stream, and returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_roi_pool_backward_f64_can_implement(
N: i32, C: i32, H: i32, W: i32,
num_rois: i32, pooled_h: i32, pooled_w: i32,
dout: *const c_void,
rois: *const c_void,
argmax: *const c_void,
dinput: *const c_void,
) -> i32;
// ---- nms (non-max suppression) ----
/// `nms(boxes, iou_thresh)` (non-max suppression), f32.
///
/// - The caller supplies `boxes` already sorted by score, descending.
/// - `boxes`: `[num_boxes, 4]` rows of `(x1, y1, x2, y2)`.
/// - `keep_mask`: `[num_boxes]` u8 values (0 / 1).
/// - `count_out`: a single i32.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// As for the declarations above. The crate-level contract applies:
/// pointers must be valid device addresses, `workspace` (when non-null)
/// must cover `workspace_bytes` writable bytes, and `stream` must be a
/// valid CUDA stream.
pub fn baracuda_kernels_nms_f32_run(
num_boxes: i32,
iou_thresh: f32,
boxes: *const c_void,
keep_mask: *mut c_void,
count_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of `baracuda_kernels_nms_f32_run`.
///
/// Takes the same problem descriptor and operand pointers, minus the
/// workspace and stream, and returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_nms_f32_can_implement(
num_boxes: i32,
iou_thresh: f32,
boxes: *const c_void,
keep_mask: *const c_void,
count_out: *const c_void,
) -> i32;
/// `nms` (non-max suppression), f64 boxes.
///
/// Same parameter list as the f32 variant. Note that `iou_thresh` is
/// declared as `f32` here as well, even though the box dtype is f64.
/// NOTE(review): confirm the C prototype also takes `float` for
/// `iou_thresh`.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// Same contract as `baracuda_kernels_nms_f32_run`.
pub fn baracuda_kernels_nms_f64_run(
num_boxes: i32,
iou_thresh: f32,
boxes: *const c_void,
keep_mask: *mut c_void,
count_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of `baracuda_kernels_nms_f64_run`.
///
/// Takes the same problem descriptor and operand pointers, minus the
/// workspace and stream, and returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_nms_f64_can_implement(
num_boxes: i32,
iou_thresh: f32,
boxes: *const c_void,
keep_mask: *const c_void,
count_out: *const c_void,
) -> i32;
}
// ============================================================================
// Phase 9 Category O — sort / argsort / msort / topk / kthvalue / unique /
// histogram / bincount / searchsorted
// ============================================================================
//
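// Sort / argsort / msort: block-bitonic, one CUDA block per row.
// Trailblazer caps `row_len ≤ 1024` (argsort additionally ships
// multi-block radix `_big` variants for `row_len > 1024`; see the
// Phase 40 section below). `descending` is `0` = ascending,
// `1` = descending. `msort` is the stable variant (tie-break on
// original index).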
//
// FW signature (sort / msort): emits BOTH sorted values AND sorted
// indices in one launch (saved-indices contract — BW reads them):
// (batch, row_len, descending, x, y_vals, y_idx, ws, ws_bytes, stream)
// FW signature (argsort): emits indices only (saves a values write
// when the caller only needs the permutation):
// (batch, row_len, descending, x, y_idx, ws, ws_bytes, stream)
// BW signature (sort / msort): pure scatter `dx[indices[i]] = dy[i]`.
// The launcher zeros `dx` first via `cudaMemsetAsync`.
// (batch, row_len, dy, indices, dx, ws, ws_bytes, stream)
//
// Topk: full block-bitonic sort, take first k cells. Trailblazer caps
// `k ≤ 64` (LLM-inference range). `largest == 1` = top-k by value;
// `largest == 0` = bottom-k.
// FW signature:
// (batch, row_len, k, largest, x, y_vals, y_idx, ws, ws_bytes, stream)
// BW signature: same scatter pattern as sort BW; `k`-wide grad routed
// back to a zero-init `row_len`-wide `dx`.
// (batch, k, row_len, dy, indices, dx, ws, ws_bytes, stream)
//
// Unique (consecutive): one-block-per-row sweep; output ordering is
// atomic-counter racy (NOT input-order). The Rust plan layer chains
// sort + this kernel to implement the un-consecutive `unique` op.
// (batch, row_len, max_unique, x, y_vals, y_counts, counter,
// ws, ws_bytes, stream)
//
// Histogram (1-D uniform bins): atomic-add per bin. `lo` / `hi` are
// passed as `double` and cast to `T` inside the macro — keeps the FFI
// shape uniform across f32 / f64.
// (n, num_bins, lo_d, hi_d, x, output, ws, ws_bytes, stream)
//
// Bincount: atomic-add per index. Out-of-range indices (`< 0` or `>=
// num_bins`) silently dropped.
// (n, num_bins, x, output, ws, ws_bytes, stream)
//
// Searchsorted: per-query binary search in a 1-D sorted array.
// `right == 0` = lower_bound; `right == 1` = upper_bound. Trailblazer
// is "single sorted_seq shared across all queries"; batched-per-row
// is a follow-up.
// (num_queries, len_sorted, right, sorted_seq, values, output,
// ws, ws_bytes, stream)
//
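// Dtype coverage:
// sort / msort FW: f32, f64, i32, i64.
// argsort FW (block-bitonic): f32, f64, i32, i64, plus the Phase 36
// fanout: u8, i8, u32, i16, bf16, f16, fp8e4m3.
// argsort FW (multi-block radix `_big`, `row_len > 1024`): f32, f64,
// i32, i64.
// sort / msort BW: f32, f64 (FP grads only).
// topk FW + BW: f32, f64.
// unique (consecutive): f32, f64, i32.
// histogram: f32, f64.
// bincount: i32, i64 input → i32 counts.
// searchsorted: f32, f64, i32, i64.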
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---------- sort FW (values + indices, ascending OR descending) ----------
/// Block-bitonic sort forward, f32.
///
/// Sorts each of the `batch` rows of `row_len` elements in `x`, emitting
/// the sorted values into `y_vals` and the sorted indices into `y_idx`
/// (saved-indices contract: the backward pass reads them).
/// `descending == 0` sorts ascending; `1` sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_sort_f32_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_vals: *mut c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_sort_f32_run`.
///
/// Takes only `batch` and `row_len` (no operand pointers) and returns a
/// crate-level status code (`0` = success, `3` = not supported).
/// NOTE(review): presumably this is where the `row_len ≤ 1024` cap is
/// reported — confirm on the C side.
pub fn baracuda_kernels_sort_f32_can_implement(
batch: i32,
row_len: i32,
) -> i32;
/// Block-bitonic sort forward, f64.
///
/// Same signature as the f32 variant: `x` in, sorted values in `y_vals`
/// and sorted indices in `y_idx` out. `descending == 0` sorts ascending;
/// `1` sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_sort_f64_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_vals: *mut c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_sort_f64_run`.
///
/// Takes only `batch` and `row_len` (no operand pointers) and returns a
/// crate-level status code (`0` = success, `3` = not supported).
/// NOTE(review): presumably this is where the `row_len ≤ 1024` cap is
/// reported — confirm on the C side.
pub fn baracuda_kernels_sort_f64_can_implement(
batch: i32,
row_len: i32,
) -> i32;
/// Block-bitonic sort forward, i32.
///
/// Same signature as the f32 variant: `x` in, sorted values in `y_vals`
/// and sorted indices in `y_idx` out. `descending == 0` sorts ascending;
/// `1` sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_sort_i32_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_vals: *mut c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_sort_i32_run`.
///
/// Takes only `batch` and `row_len` (no operand pointers) and returns a
/// crate-level status code (`0` = success, `3` = not supported).
/// NOTE(review): presumably this is where the `row_len ≤ 1024` cap is
/// reported — confirm on the C side.
pub fn baracuda_kernels_sort_i32_can_implement(
batch: i32,
row_len: i32,
) -> i32;
/// Block-bitonic sort forward, i64.
///
/// Same signature as the f32 variant: `x` in, sorted values in `y_vals`
/// and sorted indices in `y_idx` out. `descending == 0` sorts ascending;
/// `1` sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_sort_i64_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_vals: *mut c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_sort_i64_run`.
///
/// Takes only `batch` and `row_len` (no operand pointers) and returns a
/// crate-level status code (`0` = success, `3` = not supported).
/// NOTE(review): presumably this is where the `row_len ≤ 1024` cap is
/// reported — confirm on the C side.
pub fn baracuda_kernels_sort_i64_can_implement(
batch: i32,
row_len: i32,
) -> i32;
// ---------- argsort FW (indices only) ----------
/// Block-bitonic argsort forward, f32.
///
/// Writes only the sorting permutation into `y_idx`; sorted values are
/// not written. `descending == 0` sorts ascending; `1` sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_argsort_f32_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_argsort_f32_run`.
///
/// Takes only `batch` and `row_len` (no operand pointers) and returns a
/// crate-level status code (`0` = success, `3` = not supported).
/// NOTE(review): presumably this is where the `row_len ≤ 1024` cap is
/// reported — confirm on the C side.
pub fn baracuda_kernels_argsort_f32_can_implement(
batch: i32,
row_len: i32,
) -> i32;
/// Block-bitonic argsort forward, f64.
///
/// Same signature as the f32 variant: reads `x`, writes indices to
/// `y_idx`. `descending == 0` sorts ascending; `1` sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_argsort_f64_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_argsort_f64_run`.
///
/// Takes only `batch` and `row_len` (no operand pointers) and returns a
/// crate-level status code (`0` = success, `3` = not supported).
/// NOTE(review): presumably this is where the `row_len ≤ 1024` cap is
/// reported — confirm on the C side.
pub fn baracuda_kernels_argsort_f64_can_implement(
batch: i32,
row_len: i32,
) -> i32;
/// Block-bitonic argsort forward, i32.
///
/// Same signature as the f32 variant: reads `x`, writes indices to
/// `y_idx`. `descending == 0` sorts ascending; `1` sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_argsort_i32_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_argsort_i32_run`.
///
/// Takes only `batch` and `row_len` (no operand pointers) and returns a
/// crate-level status code (`0` = success, `3` = not supported).
/// NOTE(review): presumably this is where the `row_len ≤ 1024` cap is
/// reported — confirm on the C side.
pub fn baracuda_kernels_argsort_i32_can_implement(
batch: i32,
row_len: i32,
) -> i32;
/// Block-bitonic argsort forward, i64.
///
/// Same signature as the f32 variant: reads `x`, writes indices to
/// `y_idx`. `descending == 0` sorts ascending; `1` sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_argsort_i64_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_argsort_i64_run`.
///
/// Takes only `batch` and `row_len` (no operand pointers) and returns a
/// crate-level status code (`0` = success, `3` = not supported).
/// NOTE(review): presumably this is where the `row_len ≤ 1024` cap is
/// reported — confirm on the C side.
pub fn baracuda_kernels_argsort_i64_can_implement(
batch: i32,
row_len: i32,
) -> i32;
// ---------- Phase 36 (Fuel ask Gap 6a) — argsort dtype fanout ----------
//
// Block-bitonic argsort for the missing dtypes. Same `row_len ≤ 1024`
// cap as the original 4. FP8 E4M3 uses a wrapper struct on the C
// side that decodes to `float` for the comparator (the storage
// layer is still byte-identical to the raw `u8` buffer).
//
// The multi-block radix variant for `row_len > 1024` is reserved
// for a follow-up phase (Gap 6b in Fuel's brief) — it needs a
// substantially different kernel structure.
/// Block-bitonic argsort forward, u8 (Phase 36 dtype fanout).
///
/// Reads `x` and writes the sorting permutation to `y_idx`. Subject to
/// the same `row_len ≤ 1024` cap as the original four argsort dtypes.
/// `descending == 0` sorts ascending; `1` sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_argsort_u8_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_argsort_u8_run`: takes only `batch`
/// and `row_len`, returns a crate-level status code (`0` = success,
/// `3` = not supported).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_argsort_u8_can_implement(batch: i32, row_len: i32) -> i32;
/// Block-bitonic argsort forward, i8 (Phase 36 dtype fanout).
///
/// Reads `x` and writes the sorting permutation to `y_idx`. Subject to
/// the same `row_len ≤ 1024` cap as the original four argsort dtypes.
/// `descending == 0` sorts ascending; `1` sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_argsort_i8_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_argsort_i8_run`: takes only `batch`
/// and `row_len`, returns a crate-level status code (`0` = success,
/// `3` = not supported).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_argsort_i8_can_implement(batch: i32, row_len: i32) -> i32;
/// Block-bitonic argsort forward, u32 (Phase 36 dtype fanout).
///
/// Reads `x` and writes the sorting permutation to `y_idx`. Subject to
/// the same `row_len ≤ 1024` cap as the original four argsort dtypes.
/// `descending == 0` sorts ascending; `1` sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_argsort_u32_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_argsort_u32_run`: takes only `batch`
/// and `row_len`, returns a crate-level status code (`0` = success,
/// `3` = not supported).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_argsort_u32_can_implement(batch: i32, row_len: i32) -> i32;
/// Block-bitonic argsort forward, i16 (Phase 36 dtype fanout).
///
/// Reads `x` and writes the sorting permutation to `y_idx`. Subject to
/// the same `row_len ≤ 1024` cap as the original four argsort dtypes.
/// `descending == 0` sorts ascending; `1` sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_argsort_i16_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_argsort_i16_run`: takes only `batch`
/// and `row_len`, returns a crate-level status code (`0` = success,
/// `3` = not supported).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_argsort_i16_can_implement(batch: i32, row_len: i32) -> i32;
/// Block-bitonic argsort forward, bf16 (Phase 36 dtype fanout).
///
/// The comparator uses the native `__nv_bfloat16` `operator<` (CUDA
/// device-side intrinsics). Reads `x` and writes the sorting permutation
/// to `y_idx`; same `row_len ≤ 1024` cap as the original four argsort
/// dtypes. `descending == 0` sorts ascending; `1` sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_argsort_bf16_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_argsort_bf16_run`: takes only
/// `batch` and `row_len`, returns a crate-level status code
/// (`0` = success, `3` = not supported).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_argsort_bf16_can_implement(batch: i32, row_len: i32) -> i32;
/// Block-bitonic argsort forward, f16 (Phase 36 dtype fanout).
///
/// The comparator uses the native `__half` `operator<`. Reads `x` and
/// writes the sorting permutation to `y_idx`; same `row_len ≤ 1024` cap
/// as the original four argsort dtypes. `descending == 0` sorts
/// ascending; `1` sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_argsort_f16_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_argsort_f16_run`: takes only `batch`
/// and `row_len`, returns a crate-level status code (`0` = success,
/// `3` = not supported).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_argsort_f16_can_implement(batch: i32, row_len: i32) -> i32;
/// Block-bitonic argsort forward, FP8 E4M3 (Phase 36 dtype fanout).
///
/// Storage is byte-identical to a raw `u8` buffer; the kernel wraps each
/// element in an `Fp8E4M3Sort` struct that decodes to `float` inside the
/// comparator. So `x` is a raw-byte buffer and `y_idx` is an i32 index
/// buffer. `descending == 0` sorts ascending; `1` sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_argsort_fp8e4m3_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_argsort_fp8e4m3_run`: takes only
/// `batch` and `row_len`, returns a crate-level status code
/// (`0` = success, `3` = not supported).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_argsort_fp8e4m3_can_implement(
batch: i32, row_len: i32,
) -> i32;
// ---------- Phase 40 (Fuel ask Gap 6b) — multi-block radix argsort ----------
//
// For `row_len > 1024` (block-bitonic cap). Uses CUB's
// `DeviceSegmentedRadixSort::SortPairs[Descending]` under the hood.
// Caller supplies a workspace blob; size is queried via the
// `_workspace_size` companion.
//
// Dtype coverage: f32, f64, i32, i64 (the common LLM top-k logit
// dtypes). bf16 / f16 / fp8 deferred — CUB radix-sort needs a
// `cub::Traits` specialization for non-native arithmetic types and
// those would require either a custom `decomposer` (CUB 2.5+) or a
// per-row cast-to-f32 pre-pass.
//
// Calling contract:
// * `row_len <= 1024` returns status 3 (unsupported) — caller
// should dispatch to the block-bitonic `argsort_<dt>_run`.
// * Workspace shortfall returns status 4 — caller must size the
// blob using `_workspace_size`.
// * The workspace blob is consumed in full (CUB temp + scratch
// keys/indices + per-row offsets). Layout is internal; treat
// the blob as opaque.
/// Multi-block radix argsort forward, f32, for `row_len > 1024`.
///
/// Per the calling contract: `row_len <= 1024` yields status `3`
/// (dispatch to the block-bitonic `baracuda_kernels_argsort_f32_run`
/// instead), and an undersized workspace yields status `4`. The
/// workspace blob is opaque and consumed in full; size it with
/// `baracuda_kernels_argsort_f32_big_workspace_size`.
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` covering `workspace_bytes` writable bytes, and a valid
/// CUDA stream.
pub fn baracuda_kernels_argsort_f32_big_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_argsort_f32_big_run`: takes only
/// `batch` and `row_len`, returns a crate-level status code
/// (`0` = success, `3` = not supported).
/// NOTE(review): presumably reports `3` for `row_len <= 1024`, like the
/// run entry point — confirm on the C side.
pub fn baracuda_kernels_argsort_f32_big_can_implement(batch: i32, row_len: i32) -> i32;
/// Workspace-size query for `baracuda_kernels_argsort_f32_big_run`:
/// returns the size the caller must allocate for the `workspace` blob
/// for `batch` rows of `row_len` elements.
/// NOTE(review): presumably in bytes, matching `workspace_bytes` —
/// confirm on the C side.
pub fn baracuda_kernels_argsort_f32_big_workspace_size(batch: i32, row_len: i32) -> usize;
/// Multi-block radix argsort forward, f64, for `row_len > 1024`.
///
/// Per the calling contract: `row_len <= 1024` yields status `3`
/// (dispatch to the block-bitonic `baracuda_kernels_argsort_f64_run`
/// instead), and an undersized workspace yields status `4`. The
/// workspace blob is opaque and consumed in full; size it with
/// `baracuda_kernels_argsort_f64_big_workspace_size`.
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` covering `workspace_bytes` writable bytes, and a valid
/// CUDA stream.
pub fn baracuda_kernels_argsort_f64_big_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_argsort_f64_big_run`: takes only
/// `batch` and `row_len`, returns a crate-level status code
/// (`0` = success, `3` = not supported).
/// NOTE(review): presumably reports `3` for `row_len <= 1024`, like the
/// run entry point — confirm on the C side.
pub fn baracuda_kernels_argsort_f64_big_can_implement(batch: i32, row_len: i32) -> i32;
/// Workspace-size query for `baracuda_kernels_argsort_f64_big_run`:
/// returns the size the caller must allocate for the `workspace` blob
/// for `batch` rows of `row_len` elements.
/// NOTE(review): presumably in bytes, matching `workspace_bytes` —
/// confirm on the C side.
pub fn baracuda_kernels_argsort_f64_big_workspace_size(batch: i32, row_len: i32) -> usize;
/// Multi-block radix argsort forward, i32, for `row_len > 1024`.
///
/// Per the calling contract: `row_len <= 1024` yields status `3`
/// (dispatch to the block-bitonic `baracuda_kernels_argsort_i32_run`
/// instead), and an undersized workspace yields status `4`. The
/// workspace blob is opaque and consumed in full; size it with
/// `baracuda_kernels_argsort_i32_big_workspace_size`.
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` covering `workspace_bytes` writable bytes, and a valid
/// CUDA stream.
pub fn baracuda_kernels_argsort_i32_big_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_argsort_i32_big_run`: takes only
/// `batch` and `row_len`, returns a crate-level status code
/// (`0` = success, `3` = not supported).
/// NOTE(review): presumably reports `3` for `row_len <= 1024`, like the
/// run entry point — confirm on the C side.
pub fn baracuda_kernels_argsort_i32_big_can_implement(batch: i32, row_len: i32) -> i32;
/// Workspace-size query for `baracuda_kernels_argsort_i32_big_run`:
/// returns the size the caller must allocate for the `workspace` blob
/// for `batch` rows of `row_len` elements.
/// NOTE(review): presumably in bytes, matching `workspace_bytes` —
/// confirm on the C side.
pub fn baracuda_kernels_argsort_i32_big_workspace_size(batch: i32, row_len: i32) -> usize;
/// Multi-block radix argsort forward, i64, for `row_len > 1024`.
///
/// Per the calling contract: `row_len <= 1024` yields status `3`
/// (dispatch to the block-bitonic `baracuda_kernels_argsort_i64_run`
/// instead), and an undersized workspace yields status `4`. The
/// workspace blob is opaque and consumed in full; size it with
/// `baracuda_kernels_argsort_i64_big_workspace_size`.
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` covering `workspace_bytes` writable bytes, and a valid
/// CUDA stream.
pub fn baracuda_kernels_argsort_i64_big_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_argsort_i64_big_run`: takes only
/// `batch` and `row_len`, returns a crate-level status code
/// (`0` = success, `3` = not supported).
/// NOTE(review): presumably reports `3` for `row_len <= 1024`, like the
/// run entry point — confirm on the C side.
pub fn baracuda_kernels_argsort_i64_big_can_implement(batch: i32, row_len: i32) -> i32;
/// Workspace-size query for `baracuda_kernels_argsort_i64_big_run`:
/// returns the size the caller must allocate for the `workspace` blob
/// for `batch` rows of `row_len` elements.
/// NOTE(review): presumably in bytes, matching `workspace_bytes` —
/// confirm on the C side.
pub fn baracuda_kernels_argsort_i64_big_workspace_size(batch: i32, row_len: i32) -> usize;
// ---------- msort FW (stable; values + indices) ----------
/// Stable block-bitonic sort (`msort`) forward, f32.
///
/// Ties are broken on the original index, so equal keys keep their input
/// order. Emits sorted values into `y_vals` and sorted indices into
/// `y_idx`. `descending == 0` sorts ascending; `1` sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_msort_f32_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_vals: *mut c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_msort_f32_run`.
///
/// Takes only `batch` and `row_len` (no operand pointers) and returns a
/// crate-level status code (`0` = success, `3` = not supported).
/// NOTE(review): presumably this is where the `row_len ≤ 1024` cap is
/// reported — confirm on the C side.
pub fn baracuda_kernels_msort_f32_can_implement(
batch: i32,
row_len: i32,
) -> i32;
/// Stable block-bitonic sort (`msort`) forward, f64.
///
/// Same signature as the f32 variant: sorted values go to `y_vals`,
/// sorted indices to `y_idx`. `descending == 0` sorts ascending; `1`
/// sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_msort_f64_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_vals: *mut c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_msort_f64_run`.
///
/// Takes only `batch` and `row_len` (no operand pointers) and returns a
/// crate-level status code (`0` = success, `3` = not supported).
/// NOTE(review): presumably this is where the `row_len ≤ 1024` cap is
/// reported — confirm on the C side.
pub fn baracuda_kernels_msort_f64_can_implement(
batch: i32,
row_len: i32,
) -> i32;
/// Stable block-bitonic sort (`msort`) forward, i32.
///
/// Same signature as the f32 variant: sorted values go to `y_vals`,
/// sorted indices to `y_idx`. `descending == 0` sorts ascending; `1`
/// sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_msort_i32_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_vals: *mut c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_msort_i32_run`.
///
/// Takes only `batch` and `row_len` (no operand pointers) and returns a
/// crate-level status code (`0` = success, `3` = not supported).
/// NOTE(review): presumably this is where the `row_len ≤ 1024` cap is
/// reported — confirm on the C side.
pub fn baracuda_kernels_msort_i32_can_implement(
batch: i32,
row_len: i32,
) -> i32;
/// Stable block-bitonic sort (`msort`) forward, i64.
///
/// Same signature as the f32 variant: sorted values go to `y_vals`,
/// sorted indices to `y_idx`. `descending == 0` sorts ascending; `1`
/// sorts descending.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_msort_i64_run(
batch: i32,
row_len: i32,
descending: i32,
x: *const c_void,
y_vals: *mut c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_msort_i64_run`.
///
/// Takes only `batch` and `row_len` (no operand pointers) and returns a
/// crate-level status code (`0` = success, `3` = not supported).
/// NOTE(review): presumably this is where the `row_len ≤ 1024` cap is
/// reported — confirm on the C side.
pub fn baracuda_kernels_msort_i64_can_implement(
batch: i32,
row_len: i32,
) -> i32;
// ---------- sort / msort BW (scatter via saved indices) ----------
/// Sort backward, f32.
///
/// Pure scatter through the indices saved by the forward pass:
/// `dx[indices[i]] = dy[i]`. The launcher zeros `dx` first.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_sort_backward_f32_run(
batch: i32,
row_len: i32,
dy: *const c_void,
indices: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_sort_backward_f32_run`: takes only
/// `batch` and `row_len`, returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_sort_backward_f32_can_implement(
batch: i32,
row_len: i32,
) -> i32;
/// Sort backward, f64.
///
/// Same signature as the f32 variant; the shared BW contract is the
/// scatter `dx[indices[i]] = dy[i]` into a `dx` that the launcher zeros
/// first.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_sort_backward_f64_run(
batch: i32,
row_len: i32,
dy: *const c_void,
indices: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_sort_backward_f64_run`: takes only
/// `batch` and `row_len`, returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_sort_backward_f64_can_implement(
batch: i32,
row_len: i32,
) -> i32;
/// Msort backward, f32.
///
/// Performs the same scatter as sort BW (`dx[indices[i]] = dy[i]`). The
/// distinct symbol exists for FFI / telemetry parity.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_msort_backward_f32_run(
batch: i32,
row_len: i32,
dy: *const c_void,
indices: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_msort_backward_f32_run`: takes only
/// `batch` and `row_len`, returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_msort_backward_f32_can_implement(
batch: i32,
row_len: i32,
) -> i32;
/// Msort backward, f64.
///
/// Same signature as the f32 variant; the scatter is the one used by
/// sort BW (`dx[indices[i]] = dy[i]`).
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_msort_backward_f64_run(
batch: i32,
row_len: i32,
dy: *const c_void,
indices: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_msort_backward_f64_run`: takes only
/// `batch` and `row_len`, returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_msort_backward_f64_can_implement(
batch: i32,
row_len: i32,
) -> i32;
// ---------- topk FW + BW ----------
/// Block-bitonic top-k forward, f32.
///
/// Caps: `k ≤ 64` and `row_len ≤ 1024`. `largest == 1` selects the
/// top-k by value; `largest == 0` selects the bottom-k. Values go to
/// `y_vals`, their indices to `y_idx`.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_topk_f32_run(
batch: i32,
row_len: i32,
k: i32,
largest: i32,
x: *const c_void,
y_vals: *mut c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_topk_f32_run`.
///
/// Takes `batch`, `row_len`, `k` and `largest` (no operand pointers) and
/// returns a crate-level status code (`0` = success, `3` = not
/// supported).
/// NOTE(review): presumably this is where the `k` / `row_len` caps are
/// reported — confirm on the C side.
pub fn baracuda_kernels_topk_f32_can_implement(
batch: i32,
row_len: i32,
k: i32,
largest: i32,
) -> i32;
/// Block-bitonic top-k forward, f64.
///
/// Same signature as the f32 variant. `largest == 1` selects the top-k
/// by value; `largest == 0` selects the bottom-k. Values go to `y_vals`,
/// their indices to `y_idx`.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_topk_f64_run(
batch: i32,
row_len: i32,
k: i32,
largest: i32,
x: *const c_void,
y_vals: *mut c_void,
y_idx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_topk_f64_run`.
///
/// Takes `batch`, `row_len`, `k` and `largest` (no operand pointers) and
/// returns a crate-level status code (`0` = success, `3` = not
/// supported).
/// NOTE(review): presumably this is where the `k` / `row_len` caps are
/// reported — confirm on the C side.
pub fn baracuda_kernels_topk_f64_can_implement(
batch: i32,
row_len: i32,
k: i32,
largest: i32,
) -> i32;
/// Top-k backward, f32.
///
/// Scatters the `k`-wide gradient `dy` into the zero-initialised
/// `row_len`-wide `dx` through the indices saved by the forward pass.
/// Note the argument order: `k` comes BEFORE `row_len` here, the
/// reverse of the forward entry point.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_topk_backward_f32_run(
batch: i32,
k: i32,
row_len: i32,
dy: *const c_void,
indices: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_topk_backward_f32_run`: takes
/// `batch`, `k`, `row_len` (in that order) and returns a crate-level
/// status code (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_topk_backward_f32_can_implement(
batch: i32,
k: i32,
row_len: i32,
) -> i32;
/// Top-k backward, f64.
///
/// Same signature as the f32 variant: the `k`-wide `dy` is scattered
/// into the `row_len`-wide `dx` via `indices`. `k` comes BEFORE
/// `row_len` in the argument list.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_topk_backward_f64_run(
batch: i32,
k: i32,
row_len: i32,
dy: *const c_void,
indices: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Shape check for `baracuda_kernels_topk_backward_f64_run`: takes
/// `batch`, `k`, `row_len` (in that order) and returns a crate-level
/// status code (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_topk_backward_f64_can_implement(
batch: i32,
k: i32,
row_len: i32,
) -> i32;
// ---------- unique_consecutive FW (no BW — set-valued) ----------
/// Unique-consecutive forward, f32.
///
/// Emits one output cell per run-start. The order of output slots is
/// the atomic-counter race order, NOT input order. After the launch,
/// `counter[row]` holds the actual number of unique runs for that row.
/// NOTE(review): `max_unique` presumably bounds the per-row output
/// capacity of `y_vals` / `y_counts` — confirm on the C side.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_unique_consecutive_f32_run(
batch: i32,
row_len: i32,
max_unique: i32,
x: *const c_void,
y_vals: *mut c_void,
y_counts: *mut c_void,
counter: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_unique_consecutive_f32_run`.
///
/// Takes the same problem descriptor and operand pointers, minus the
/// workspace and stream, and returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_unique_consecutive_f32_can_implement(
batch: i32,
row_len: i32,
max_unique: i32,
x: *const c_void,
y_vals: *const c_void,
y_counts: *const c_void,
counter: *const c_void,
) -> i32;
/// Unique-consecutive forward, f64.
///
/// Same signature as the f32 variant: run-start values go to `y_vals`,
/// with `y_counts` and the per-row `counter` as additional outputs.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_unique_consecutive_f64_run(
batch: i32,
row_len: i32,
max_unique: i32,
x: *const c_void,
y_vals: *mut c_void,
y_counts: *mut c_void,
counter: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_unique_consecutive_f64_run`.
///
/// Takes the same problem descriptor and operand pointers, minus the
/// workspace and stream, and returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_unique_consecutive_f64_can_implement(
batch: i32,
row_len: i32,
max_unique: i32,
x: *const c_void,
y_vals: *const c_void,
y_counts: *const c_void,
counter: *const c_void,
) -> i32;
/// Unique-consecutive forward, i32.
///
/// Same signature as the f32 variant: run-start values go to `y_vals`,
/// with `y_counts` and the per-row `counter` as additional outputs.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_unique_consecutive_i32_run(
batch: i32,
row_len: i32,
max_unique: i32,
x: *const c_void,
y_vals: *mut c_void,
y_counts: *mut c_void,
counter: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of
/// `baracuda_kernels_unique_consecutive_i32_run`.
///
/// Takes the same problem descriptor and operand pointers, minus the
/// workspace and stream, and returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_unique_consecutive_i32_can_implement(
batch: i32,
row_len: i32,
max_unique: i32,
x: *const c_void,
y_vals: *const c_void,
y_counts: *const c_void,
counter: *const c_void,
) -> i32;
// ---------- histogram FW (1-D uniform bins, FP input, i32 counts) ----------
/// 1-D uniform-bin histogram forward, f32 input, i32 counts.
///
/// The range bounds `lo_d` / `hi_d` are passed as `double` and cast to
/// the element type inside the kernel, which keeps the FFI shape
/// identical across dtypes. `n` is an `i64` element count.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_histogram_f32_run(
n: i64,
num_bins: i32,
lo_d: f64,
hi_d: f64,
x: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of `baracuda_kernels_histogram_f32_run`.
///
/// Takes the same problem descriptor and operand pointers, minus the
/// workspace and stream, and returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_histogram_f32_can_implement(
n: i64,
num_bins: i32,
lo_d: f64,
hi_d: f64,
x: *const c_void,
output: *const c_void,
) -> i32;
/// 1-D uniform-bin histogram forward, f64 input, i32 counts.
///
/// Same signature as the f32 variant; `lo_d` / `hi_d` are `double` in
/// both.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_histogram_f64_run(
n: i64,
num_bins: i32,
lo_d: f64,
hi_d: f64,
x: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of `baracuda_kernels_histogram_f64_run`.
///
/// Takes the same problem descriptor and operand pointers, minus the
/// workspace and stream, and returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_histogram_f64_can_implement(
n: i64,
num_bins: i32,
lo_d: f64,
hi_d: f64,
x: *const c_void,
output: *const c_void,
) -> i32;
// ---------- bincount FW (int input, i32 counts) ----------
/// `bincount` forward, i32 input, i32 counts.
///
/// Indices that are out of range (`< 0` or `>= num_bins`) are silently
/// dropped. `n` is an `i64` element count.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_bincount_i32_run(
n: i64,
num_bins: i32,
x: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of `baracuda_kernels_bincount_i32_run`.
///
/// Takes the same problem descriptor and operand pointers, minus the
/// workspace and stream, and returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_bincount_i32_can_implement(
n: i64,
num_bins: i32,
x: *const c_void,
output: *const c_void,
) -> i32;
/// `bincount` forward, i64 input, i32 counts.
///
/// Same signature as the i32-input variant. Note that `num_bins` stays
/// `i32` even though the input indices are i64.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_bincount_i64_run(
n: i64,
num_bins: i32,
x: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of `baracuda_kernels_bincount_i64_run`.
///
/// Takes the same problem descriptor and operand pointers, minus the
/// workspace and stream, and returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_bincount_i64_can_implement(
n: i64,
num_bins: i32,
x: *const c_void,
output: *const c_void,
) -> i32;
// ---------- searchsorted FW (1-D sorted_seq, per-query binary search) ----------
/// `searchsorted` forward, f32.
///
/// Per-query binary search of each element of `values` in the 1-D
/// `sorted_seq`, which is shared by all queries. `right == 0` gives the
/// lower bound; `right == 1` gives the upper bound. `num_queries` is
/// `i64`; `len_sorted` is `i32`.
///
/// Returns a crate-level status code (`0` = success).
///
/// # Safety
/// The crate-level FFI contract applies: valid device pointers,
/// `workspace` (when non-null) covering `workspace_bytes` writable
/// bytes, and a valid CUDA stream.
pub fn baracuda_kernels_searchsorted_f32_run(
num_queries: i64,
len_sorted: i32,
right: i32,
sorted_seq: *const c_void,
values: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` companion of `baracuda_kernels_searchsorted_f32_run`.
///
/// Takes the same problem descriptor and operand pointers, minus the
/// workspace and stream, and returns a crate-level status code
/// (`0` = success).
/// NOTE(review): assumed to validate only, without launching a kernel —
/// confirm on the C side.
pub fn baracuda_kernels_searchsorted_f32_can_implement(
num_queries: i64,
len_sorted: i32,
right: i32,
sorted_seq: *const c_void,
values: *const c_void,
output: *const c_void,
) -> i32;
/// `searchsorted`, f64.
pub fn baracuda_kernels_searchsorted_f64_run(
num_queries: i64,
len_sorted: i32,
right: i32,
sorted_seq: *const c_void,
values: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_searchsorted_f64_run`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is passed as `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_searchsorted_f64_can_implement(
num_queries: i64,
len_sorted: i32,
right: i32,
sorted_seq: *const c_void,
values: *const c_void,
output: *const c_void,
) -> i32;
/// Launches `searchsorted` for i32 data: a per-query binary search over
/// a 1-D sorted sequence.
///
/// # Parameters
/// - `num_queries`: number of entries in `values` (inferred from the
///   name — confirm against the kernel source).
/// - `len_sorted`: length of `sorted_seq` (inferred from the name).
/// - `right`: bound selector; presumably the same convention as the f32
///   entry point (`0` = lower_bound, `1` = upper_bound) — confirm.
/// - `sorted_seq`: the 1-D sorted sequence (read-only).
/// - `values`: the query values (read-only).
/// - `output`: result buffer (written); its element type is not visible
///   in this declaration.
/// - `workspace` / `workspace_bytes`: scratch pointer and its size in
///   bytes; a non-null `workspace` must cover at least `workspace_bytes`
///   of writable device memory (crate-level contract).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_searchsorted_i32_run(
num_queries: i64,
len_sorted: i32,
right: i32,
sorted_seq: *const c_void,
values: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_searchsorted_i32_run`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is passed as `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_searchsorted_i32_can_implement(
num_queries: i64,
len_sorted: i32,
right: i32,
sorted_seq: *const c_void,
values: *const c_void,
output: *const c_void,
) -> i32;
/// Launches `searchsorted` for i64 data: a per-query binary search over
/// a 1-D sorted sequence.
///
/// # Parameters
/// - `num_queries`: number of entries in `values` (inferred from the
///   name — confirm against the kernel source).
/// - `len_sorted`: length of `sorted_seq` (inferred from the name).
/// - `right`: bound selector; presumably the same convention as the f32
///   entry point (`0` = lower_bound, `1` = upper_bound) — confirm.
/// - `sorted_seq`: the 1-D sorted sequence (read-only).
/// - `values`: the query values (read-only).
/// - `output`: result buffer (written); its element type is not visible
///   in this declaration.
/// - `workspace` / `workspace_bytes`: scratch pointer and its size in
///   bytes; a non-null `workspace` must cover at least `workspace_bytes`
///   of writable device memory (crate-level contract).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_searchsorted_i64_run(
num_queries: i64,
len_sorted: i32,
right: i32,
sorted_seq: *const c_void,
values: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_searchsorted_i64_run`.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `output` is passed as `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_searchsorted_i64_can_implement(
num_queries: i64,
len_sorted: i32,
right: i32,
sorted_seq: *const c_void,
values: *const c_void,
output: *const c_void,
) -> i32;
// ----- WriteSlice (Phase 13.1) -----------------------------------------
//
// `dest[start_0..end_0, ..., start_{N-1}..end_{N-1}] = source`
// (assign, not accumulate). Both tensors contiguous, zero-offset.
// 1 <= rank <= 8. Five byte-width-dispatched symbols cover all
// byte-aligned dtypes (one per `sizeof(T)` ∈ {1, 2, 4, 8, 16});
// one nibble-packed symbol covers S4 / U4 with the even-alignment
// constraint on the innermost axis.
/// WriteSlice, 1-byte element (i8 / u8 / S8 / U8 / Bool / Fp8E4M3 /
/// Fp8E5M2). Generic per-slab-element memcpy kernel.
///
/// Assigns (does not accumulate) `source` into the slice of `dest` that
/// starts at `range_start`. No end index is passed; the slice extent is
/// presumably `source_shape` — confirm against the kernel source.
///
/// # Parameters
/// - `dest`: destination tensor (written).
/// - `source`: source tensor (read-only).
/// - `source_numel`: element count of `source`.
/// - `rank`: number of axes, i.e. the length of the three shape arrays.
/// - `dest_shape`, `source_shape`, `range_start`: per-axis extents and
///   slice origin.
/// - `workspace` / `workspace_bytes`: scratch pointer and its size in
///   bytes (crate-level workspace contract).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// `dest` must be a contiguous device buffer of at least
/// `prod(dest_shape) * 1` bytes. `source` must be a contiguous device
/// buffer of at least `source_numel * 1` bytes. `dest_shape`,
/// `source_shape`, `range_start` are host-side `int32_t` arrays of
/// length `rank`. `stream` must be a live CUDA stream.
pub fn baracuda_kernels_write_slice_b1_run(
dest: *mut c_void,
source: *const c_void,
source_numel: i64,
rank: i32,
dest_shape: *const i32,
source_shape: *const i32,
range_start: *const i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_write_slice_b1`. Host-side only.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `dest` is passed as `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_write_slice_b1_can_implement(
dest: *const c_void,
source: *const c_void,
source_numel: i64,
rank: i32,
dest_shape: *const i32,
source_shape: *const i32,
range_start: *const i32,
) -> i32;
/// WriteSlice, 2-byte element (f16 / bf16). See `b1` variant for
/// the contract.
///
/// Assigns (does not accumulate) `source` into the slice of `dest` that
/// starts at `range_start`.
///
/// # Parameters
/// - `dest`: destination tensor (written).
/// - `source`: source tensor (read-only).
/// - `source_numel`: element count of `source`.
/// - `rank`: number of axes, i.e. the length of the three shape arrays.
/// - `dest_shape`, `source_shape`, `range_start`: `rank`-length `i32`
///   arrays.
/// - `workspace` / `workspace_bytes`: scratch pointer and its size in
///   bytes (crate-level workspace contract).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same contract as the `b1` entry point. NOTE(review): buffer sizes
/// presumably scale to 2 bytes per element
/// (`prod(dest_shape) * 2`, `source_numel * 2`) — confirm.
pub fn baracuda_kernels_write_slice_b2_run(
dest: *mut c_void,
source: *const c_void,
source_numel: i64,
rank: i32,
dest_shape: *const i32,
source_shape: *const i32,
range_start: *const i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_write_slice_b2`. Host-side only.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `dest` is passed as `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_write_slice_b2_can_implement(
dest: *const c_void,
source: *const c_void,
source_numel: i64,
rank: i32,
dest_shape: *const i32,
source_shape: *const i32,
range_start: *const i32,
) -> i32;
/// WriteSlice, 4-byte element (f32 / F32Strict / i32).
///
/// Assigns (does not accumulate) `source` into the slice of `dest` that
/// starts at `range_start`.
///
/// # Parameters
/// - `dest`: destination tensor (written).
/// - `source`: source tensor (read-only).
/// - `source_numel`: element count of `source` (by name).
/// - `rank`: number of axes, i.e. the length of the three shape arrays.
/// - `dest_shape`, `source_shape`, `range_start`: `rank`-length `i32`
///   arrays.
/// - `workspace` / `workspace_bytes`: scratch pointer and its size in
///   bytes (crate-level workspace contract).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// NOTE(review): the signature matches the 1-byte entry point, so the
/// buffer / stream contract presumably matches too, with sizes scaled to
/// 4 bytes per element — confirm.
pub fn baracuda_kernels_write_slice_b4_run(
dest: *mut c_void,
source: *const c_void,
source_numel: i64,
rank: i32,
dest_shape: *const i32,
source_shape: *const i32,
range_start: *const i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_write_slice_b4`. Host-side only.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `dest` is passed as `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_write_slice_b4_can_implement(
dest: *const c_void,
source: *const c_void,
source_numel: i64,
rank: i32,
dest_shape: *const i32,
source_shape: *const i32,
range_start: *const i32,
) -> i32;
/// WriteSlice, 8-byte element (f64 / i64 / Complex32).
///
/// Assigns (does not accumulate) `source` into the slice of `dest` that
/// starts at `range_start`.
///
/// # Parameters
/// - `dest`: destination tensor (written).
/// - `source`: source tensor (read-only).
/// - `source_numel`: element count of `source` (by name).
/// - `rank`: number of axes, i.e. the length of the three shape arrays.
/// - `dest_shape`, `source_shape`, `range_start`: `rank`-length `i32`
///   arrays.
/// - `workspace` / `workspace_bytes`: scratch pointer and its size in
///   bytes (crate-level workspace contract).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// NOTE(review): the signature matches the 1-byte entry point, so the
/// buffer / stream contract presumably matches too, with sizes scaled to
/// 8 bytes per element — confirm.
pub fn baracuda_kernels_write_slice_b8_run(
dest: *mut c_void,
source: *const c_void,
source_numel: i64,
rank: i32,
dest_shape: *const i32,
source_shape: *const i32,
range_start: *const i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_write_slice_b8`. Host-side only.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `dest` is passed as `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_write_slice_b8_can_implement(
dest: *const c_void,
source: *const c_void,
source_numel: i64,
rank: i32,
dest_shape: *const i32,
source_shape: *const i32,
range_start: *const i32,
) -> i32;
/// WriteSlice, 16-byte element (Complex64).
///
/// Assigns (does not accumulate) `source` into the slice of `dest` that
/// starts at `range_start`.
///
/// # Parameters
/// - `dest`: destination tensor (written).
/// - `source`: source tensor (read-only).
/// - `source_numel`: element count of `source` (by name).
/// - `rank`: number of axes, i.e. the length of the three shape arrays.
/// - `dest_shape`, `source_shape`, `range_start`: `rank`-length `i32`
///   arrays.
/// - `workspace` / `workspace_bytes`: scratch pointer and its size in
///   bytes (crate-level workspace contract).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// NOTE(review): the signature matches the 1-byte entry point, so the
/// buffer / stream contract presumably matches too, with sizes scaled to
/// 16 bytes per element — confirm.
pub fn baracuda_kernels_write_slice_b16_run(
dest: *mut c_void,
source: *const c_void,
source_numel: i64,
rank: i32,
dest_shape: *const i32,
source_shape: *const i32,
range_start: *const i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_write_slice_b16`. Host-side only.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `dest` is passed as `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_write_slice_b16_can_implement(
dest: *const c_void,
source: *const c_void,
source_numel: i64,
rank: i32,
dest_shape: *const i32,
source_shape: *const i32,
range_start: *const i32,
) -> i32;
/// WriteSlice, nibble-packed (S4 / U4 — two elements per byte).
///
/// Constraint: `range_start[rank-1]` and `range_end[rank-1]` must
/// both be even so no read-modify-write straddles a byte boundary.
/// (`range_end` is not a parameter of this function; it refers to the
/// element-space slice end on the caller's side.)
/// Shape / range_start arrays passed in are *byte*-counted on the
/// innermost axis (Rust side halves before calling).
///
/// # Parameters
/// - `dest`: destination tensor (written).
/// - `source`: source tensor (read-only).
/// - `source_byte_numel`: size of `source` counted in bytes rather than
///   nibbles (by name — confirm).
/// - `rank`: number of axes, i.e. the length of the three arrays.
/// - `dest_byte_shape`, `source_byte_shape`, `range_start_bytes`:
///   per-axis extents and slice origin, with the innermost axis already
///   halved to a byte count.
/// - `workspace` / `workspace_bytes`: scratch pointer and its size in
///   bytes (crate-level workspace contract).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
///
/// # Safety
/// Same buffer / stream contract as the byte-aligned variants.
pub fn baracuda_kernels_write_slice_nibble_run(
dest: *mut c_void,
source: *const c_void,
source_byte_numel: i64,
rank: i32,
dest_byte_shape: *const i32,
source_byte_shape: *const i32,
range_start_bytes: *const i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_write_slice_nibble`. Host-side only.
///
/// Takes the `_run` arguments minus `workspace`, `workspace_bytes` and
/// `stream`; `dest` is passed as `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
///
/// # Safety
/// No device dereferences; same pointer-validity contract as the matching `_run`.
pub fn baracuda_kernels_write_slice_nibble_can_implement(
dest: *const c_void,
source: *const c_void,
source_byte_numel: i64,
rank: i32,
dest_byte_shape: *const i32,
source_byte_shape: *const i32,
range_start_bytes: *const i32,
) -> i32;
// ========================================================================
// Phase 16.2 — LpPool 1d/2d fused bespoke kernels.
// ========================================================================
//
// `y = (Σ_{k in window} |x_k|^p)^(1/p)` over an NCL (1d) / NCHW (2d)
// tensor. No padding (PyTorch's `LpPool*d` has no `pad` argument).
// Window clamps to input boundary when `ceil_mode = true` produces a
// window that overhangs.
//
// `norm_p` is `f32` for every dtype (f64 path included — the
// exponent precision rarely matters past 5 decimal places and the
// f32 ABI keeps the symbol set uniform).
//
// BW caller must zero `dx` before launch — kernel uses atomicAdd
// scatter.
//
// The `ceil_mode` flag is accepted for ABI uniformity but is
// *unused* by the kernel — the safe layer computes the output
// extent under the requested ceil_mode and passes it as `l_out` /
// `(h_out, w_out)`. Kernel just iterates `[0, l_out)`.
/// LpPool 1d FW, f32.
///
/// Computes `y = (Σ_{k in window} |x_k|^p)^(1/p)` over an NCL tensor;
/// there is no padding argument.
///
/// # Parameters
/// - `x`: input (read-only); `y`: output (written).
/// - `batch`, `channels`, `l_in`: input extents in NCL order.
/// - `kernel`, `stride`: pooling window length and step (by name).
/// - `l_out`: output length precomputed by the caller; the kernel just
///   iterates `[0, l_out)`.
/// - `norm_p`: the exponent `p`, passed as `f32` for every dtype.
/// - `ceil_mode`: accepted for ABI uniformity but unused by the kernel.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_lp_pool_1d_f32_run(
x: *const c_void, y: *mut c_void,
batch: i32, channels: i32, l_in: i32,
kernel: i32, stride: i32, l_out: i32,
norm_p: f32, ceil_mode: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_lp_pool_1d_f32_run`.
///
/// Takes the `_run` arguments minus `stream`; `y` is passed as `*const`
/// here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_lp_pool_1d_f32_can_implement(
x: *const c_void, y: *const c_void,
batch: i32, channels: i32, l_in: i32,
kernel: i32, stride: i32, l_out: i32,
norm_p: f32, ceil_mode: i32,
) -> i32;
/// LpPool 1d FW, f64.
///
/// Computes `y = (Σ_{k in window} |x_k|^p)^(1/p)` over an NCL tensor;
/// there is no padding argument.
///
/// # Parameters
/// - `x`: input (read-only); `y`: output (written).
/// - `batch`, `channels`, `l_in`: input extents in NCL order.
/// - `kernel`, `stride`: pooling window length and step (by name).
/// - `l_out`: output length precomputed by the caller; the kernel just
///   iterates `[0, l_out)`.
/// - `norm_p`: the exponent `p`; stays `f32` even on this f64 path so
///   the symbol set keeps a uniform ABI.
/// - `ceil_mode`: accepted for ABI uniformity but unused by the kernel.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_lp_pool_1d_f64_run(
x: *const c_void, y: *mut c_void,
batch: i32, channels: i32, l_in: i32,
kernel: i32, stride: i32, l_out: i32,
norm_p: f32, ceil_mode: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_lp_pool_1d_f64_run`.
///
/// Takes the `_run` arguments minus `stream`; `y` is passed as `*const`
/// here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_lp_pool_1d_f64_can_implement(
x: *const c_void, y: *const c_void,
batch: i32, channels: i32, l_in: i32,
kernel: i32, stride: i32, l_out: i32,
norm_p: f32, ceil_mode: i32,
) -> i32;
/// LpPool 1d FW, f16.
///
/// Computes `y = (Σ_{k in window} |x_k|^p)^(1/p)` over an NCL tensor;
/// there is no padding argument.
///
/// # Parameters
/// - `x`: input (read-only); `y`: output (written).
/// - `batch`, `channels`, `l_in`: input extents in NCL order.
/// - `kernel`, `stride`: pooling window length and step (by name).
/// - `l_out`: output length precomputed by the caller; the kernel just
///   iterates `[0, l_out)`.
/// - `norm_p`: the exponent `p`, passed as `f32` for every dtype.
/// - `ceil_mode`: accepted for ABI uniformity but unused by the kernel.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_lp_pool_1d_f16_run(
x: *const c_void, y: *mut c_void,
batch: i32, channels: i32, l_in: i32,
kernel: i32, stride: i32, l_out: i32,
norm_p: f32, ceil_mode: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_lp_pool_1d_f16_run`.
///
/// Takes the `_run` arguments minus `stream`; `y` is passed as `*const`
/// here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_lp_pool_1d_f16_can_implement(
x: *const c_void, y: *const c_void,
batch: i32, channels: i32, l_in: i32,
kernel: i32, stride: i32, l_out: i32,
norm_p: f32, ceil_mode: i32,
) -> i32;
/// LpPool 1d FW, bf16.
///
/// Computes `y = (Σ_{k in window} |x_k|^p)^(1/p)` over an NCL tensor;
/// there is no padding argument.
///
/// # Parameters
/// - `x`: input (read-only); `y`: output (written).
/// - `batch`, `channels`, `l_in`: input extents in NCL order.
/// - `kernel`, `stride`: pooling window length and step (by name).
/// - `l_out`: output length precomputed by the caller; the kernel just
///   iterates `[0, l_out)`.
/// - `norm_p`: the exponent `p`, passed as `f32` for every dtype.
/// - `ceil_mode`: accepted for ABI uniformity but unused by the kernel.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_lp_pool_1d_bf16_run(
x: *const c_void, y: *mut c_void,
batch: i32, channels: i32, l_in: i32,
kernel: i32, stride: i32, l_out: i32,
norm_p: f32, ceil_mode: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_lp_pool_1d_bf16_run`.
///
/// Takes the `_run` arguments minus `stream`; `y` is passed as `*const`
/// here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_lp_pool_1d_bf16_can_implement(
x: *const c_void, y: *const c_void,
batch: i32, channels: i32, l_in: i32,
kernel: i32, stride: i32, l_out: i32,
norm_p: f32, ceil_mode: i32,
) -> i32;
/// LpPool 1d BW, f32. Caller must zero `dx` first.
///
/// The kernel scatters into `dx` with atomicAdd, which is why `dx` has
/// to be zeroed before launch.
///
/// # Parameters
/// - `x`, `y`: read-only; presumably the forward input and output of
///   the matching FW call — confirm.
/// - `dy`: incoming gradient (read-only).
/// - `dx`: gradient buffer that is accumulated into (written).
/// - `batch`, `channels`, `l_in`: input extents in NCL order.
/// - `kernel`, `stride`: pooling window length and step (by name).
/// - `l_out`: output length precomputed by the caller.
/// - `norm_p`: the exponent `p`, passed as `f32` for every dtype.
/// - `ceil_mode`: accepted for ABI uniformity but unused by the kernel.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_lp_pool_1d_f32_backward_run(
x: *const c_void, y: *const c_void, dy: *const c_void, dx: *mut c_void,
batch: i32, channels: i32, l_in: i32,
kernel: i32, stride: i32, l_out: i32,
norm_p: f32, ceil_mode: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_lp_pool_1d_f32_backward_run`.
///
/// Takes the `_run` arguments minus `stream`; `dx` is passed as
/// `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_lp_pool_1d_f32_backward_can_implement(
x: *const c_void, y: *const c_void, dy: *const c_void, dx: *const c_void,
batch: i32, channels: i32, l_in: i32,
kernel: i32, stride: i32, l_out: i32,
norm_p: f32, ceil_mode: i32,
) -> i32;
/// LpPool 1d BW, f64. Caller must zero `dx` before launch — the kernel
/// scatters into it with atomicAdd.
///
/// # Parameters
/// - `x`, `y`: read-only; presumably the forward input and output of
///   the matching FW call — confirm.
/// - `dy`: incoming gradient (read-only).
/// - `dx`: gradient buffer that is accumulated into (written).
/// - `batch`, `channels`, `l_in`: input extents in NCL order.
/// - `kernel`, `stride`: pooling window length and step (by name).
/// - `l_out`: output length precomputed by the caller.
/// - `norm_p`: the exponent `p`; stays `f32` even on this f64 path.
/// - `ceil_mode`: accepted for ABI uniformity but unused by the kernel.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_lp_pool_1d_f64_backward_run(
x: *const c_void, y: *const c_void, dy: *const c_void, dx: *mut c_void,
batch: i32, channels: i32, l_in: i32,
kernel: i32, stride: i32, l_out: i32,
norm_p: f32, ceil_mode: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_lp_pool_1d_f64_backward_run`.
///
/// Takes the `_run` arguments minus `stream`; `dx` is passed as
/// `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_lp_pool_1d_f64_backward_can_implement(
x: *const c_void, y: *const c_void, dy: *const c_void, dx: *const c_void,
batch: i32, channels: i32, l_in: i32,
kernel: i32, stride: i32, l_out: i32,
norm_p: f32, ceil_mode: i32,
) -> i32;
/// LpPool 1d BW, f16. Caller must zero `dx` before launch — the kernel
/// scatters into it with atomicAdd.
///
/// # Parameters
/// - `x`, `y`: read-only; presumably the forward input and output of
///   the matching FW call — confirm.
/// - `dy`: incoming gradient (read-only).
/// - `dx`: gradient buffer that is accumulated into (written).
/// - `batch`, `channels`, `l_in`: input extents in NCL order.
/// - `kernel`, `stride`: pooling window length and step (by name).
/// - `l_out`: output length precomputed by the caller.
/// - `norm_p`: the exponent `p`, passed as `f32` for every dtype.
/// - `ceil_mode`: accepted for ABI uniformity but unused by the kernel.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_lp_pool_1d_f16_backward_run(
x: *const c_void, y: *const c_void, dy: *const c_void, dx: *mut c_void,
batch: i32, channels: i32, l_in: i32,
kernel: i32, stride: i32, l_out: i32,
norm_p: f32, ceil_mode: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_lp_pool_1d_f16_backward_run`.
///
/// Takes the `_run` arguments minus `stream`; `dx` is passed as
/// `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_lp_pool_1d_f16_backward_can_implement(
x: *const c_void, y: *const c_void, dy: *const c_void, dx: *const c_void,
batch: i32, channels: i32, l_in: i32,
kernel: i32, stride: i32, l_out: i32,
norm_p: f32, ceil_mode: i32,
) -> i32;
/// LpPool 1d BW, bf16. Caller must zero `dx` before launch — the kernel
/// scatters into it with atomicAdd.
///
/// # Parameters
/// - `x`, `y`: read-only; presumably the forward input and output of
///   the matching FW call — confirm.
/// - `dy`: incoming gradient (read-only).
/// - `dx`: gradient buffer that is accumulated into (written).
/// - `batch`, `channels`, `l_in`: input extents in NCL order.
/// - `kernel`, `stride`: pooling window length and step (by name).
/// - `l_out`: output length precomputed by the caller.
/// - `norm_p`: the exponent `p`, passed as `f32` for every dtype.
/// - `ceil_mode`: accepted for ABI uniformity but unused by the kernel.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_lp_pool_1d_bf16_backward_run(
x: *const c_void, y: *const c_void, dy: *const c_void, dx: *mut c_void,
batch: i32, channels: i32, l_in: i32,
kernel: i32, stride: i32, l_out: i32,
norm_p: f32, ceil_mode: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_lp_pool_1d_bf16_backward_run`.
///
/// Takes the `_run` arguments minus `stream`; `dx` is passed as
/// `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_lp_pool_1d_bf16_backward_can_implement(
x: *const c_void, y: *const c_void, dy: *const c_void, dx: *const c_void,
batch: i32, channels: i32, l_in: i32,
kernel: i32, stride: i32, l_out: i32,
norm_p: f32, ceil_mode: i32,
) -> i32;
/// LpPool 2d FW, f32.
///
/// Computes `y = (Σ_{k in window} |x_k|^p)^(1/p)` over an NCHW tensor;
/// there is no padding argument.
///
/// # Parameters
/// - `x`: input (read-only); `y`: output (written).
/// - `batch`, `channels`, `h_in`, `w_in`: input extents in NCHW order.
/// - `kh`, `kw`: pooling window height / width (by name).
/// - `sh`, `sw`: stride along H / W (by name).
/// - `h_out`, `w_out`: output extents precomputed by the caller under
///   the requested ceil_mode.
/// - `norm_p`: the exponent `p`, passed as `f32` for every dtype.
/// - `ceil_mode`: accepted for ABI uniformity but unused by the kernel.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_lp_pool_2d_f32_run(
x: *const c_void, y: *mut c_void,
batch: i32, channels: i32, h_in: i32, w_in: i32,
kh: i32, kw: i32, sh: i32, sw: i32,
h_out: i32, w_out: i32,
norm_p: f32, ceil_mode: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_lp_pool_2d_f32_run`.
///
/// Takes the `_run` arguments minus `stream`; `y` is passed as `*const`
/// here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_lp_pool_2d_f32_can_implement(
x: *const c_void, y: *const c_void,
batch: i32, channels: i32, h_in: i32, w_in: i32,
kh: i32, kw: i32, sh: i32, sw: i32,
h_out: i32, w_out: i32,
norm_p: f32, ceil_mode: i32,
) -> i32;
/// LpPool 2d FW, f64.
///
/// Computes `y = (Σ_{k in window} |x_k|^p)^(1/p)` over an NCHW tensor;
/// there is no padding argument.
///
/// # Parameters
/// - `x`: input (read-only); `y`: output (written).
/// - `batch`, `channels`, `h_in`, `w_in`: input extents in NCHW order.
/// - `kh`, `kw`: pooling window height / width (by name).
/// - `sh`, `sw`: stride along H / W (by name).
/// - `h_out`, `w_out`: output extents precomputed by the caller under
///   the requested ceil_mode.
/// - `norm_p`: the exponent `p`; stays `f32` even on this f64 path so
///   the symbol set keeps a uniform ABI.
/// - `ceil_mode`: accepted for ABI uniformity but unused by the kernel.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_lp_pool_2d_f64_run(
x: *const c_void, y: *mut c_void,
batch: i32, channels: i32, h_in: i32, w_in: i32,
kh: i32, kw: i32, sh: i32, sw: i32,
h_out: i32, w_out: i32,
norm_p: f32, ceil_mode: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_lp_pool_2d_f64_run`.
///
/// Takes the `_run` arguments minus `stream`; `y` is passed as `*const`
/// here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_lp_pool_2d_f64_can_implement(
x: *const c_void, y: *const c_void,
batch: i32, channels: i32, h_in: i32, w_in: i32,
kh: i32, kw: i32, sh: i32, sw: i32,
h_out: i32, w_out: i32,
norm_p: f32, ceil_mode: i32,
) -> i32;
/// LpPool 2d FW, f16.
///
/// Computes `y = (Σ_{k in window} |x_k|^p)^(1/p)` over an NCHW tensor;
/// there is no padding argument.
///
/// # Parameters
/// - `x`: input (read-only); `y`: output (written).
/// - `batch`, `channels`, `h_in`, `w_in`: input extents in NCHW order.
/// - `kh`, `kw`: pooling window height / width (by name).
/// - `sh`, `sw`: stride along H / W (by name).
/// - `h_out`, `w_out`: output extents precomputed by the caller under
///   the requested ceil_mode.
/// - `norm_p`: the exponent `p`, passed as `f32` for every dtype.
/// - `ceil_mode`: accepted for ABI uniformity but unused by the kernel.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_lp_pool_2d_f16_run(
x: *const c_void, y: *mut c_void,
batch: i32, channels: i32, h_in: i32, w_in: i32,
kh: i32, kw: i32, sh: i32, sw: i32,
h_out: i32, w_out: i32,
norm_p: f32, ceil_mode: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_lp_pool_2d_f16_run`.
///
/// Takes the `_run` arguments minus `stream`; `y` is passed as `*const`
/// here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_lp_pool_2d_f16_can_implement(
x: *const c_void, y: *const c_void,
batch: i32, channels: i32, h_in: i32, w_in: i32,
kh: i32, kw: i32, sh: i32, sw: i32,
h_out: i32, w_out: i32,
norm_p: f32, ceil_mode: i32,
) -> i32;
/// LpPool 2d FW, bf16.
///
/// Computes `y = (Σ_{k in window} |x_k|^p)^(1/p)` over an NCHW tensor;
/// there is no padding argument.
///
/// # Parameters
/// - `x`: input (read-only); `y`: output (written).
/// - `batch`, `channels`, `h_in`, `w_in`: input extents in NCHW order.
/// - `kh`, `kw`: pooling window height / width (by name).
/// - `sh`, `sw`: stride along H / W (by name).
/// - `h_out`, `w_out`: output extents precomputed by the caller under
///   the requested ceil_mode.
/// - `norm_p`: the exponent `p`, passed as `f32` for every dtype.
/// - `ceil_mode`: accepted for ABI uniformity but unused by the kernel.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_lp_pool_2d_bf16_run(
x: *const c_void, y: *mut c_void,
batch: i32, channels: i32, h_in: i32, w_in: i32,
kh: i32, kw: i32, sh: i32, sw: i32,
h_out: i32, w_out: i32,
norm_p: f32, ceil_mode: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for `baracuda_kernels_lp_pool_2d_bf16_run`.
///
/// Takes the `_run` arguments minus `stream`; `y` is passed as `*const`
/// here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_lp_pool_2d_bf16_can_implement(
x: *const c_void, y: *const c_void,
batch: i32, channels: i32, h_in: i32, w_in: i32,
kh: i32, kw: i32, sh: i32, sw: i32,
h_out: i32, w_out: i32,
norm_p: f32, ceil_mode: i32,
) -> i32;
/// LpPool 2d BW, f32. Caller must zero `dx` first.
///
/// The kernel scatters into `dx` with atomicAdd, which is why `dx` has
/// to be zeroed before launch.
///
/// # Parameters
/// - `x`, `y`: read-only; presumably the forward input and output of
///   the matching FW call — confirm.
/// - `dy`: incoming gradient (read-only).
/// - `dx`: gradient buffer that is accumulated into (written).
/// - `batch`, `channels`, `h_in`, `w_in`: input extents in NCHW order.
/// - `kh`, `kw` / `sh`, `sw`: window size and stride along H / W (by
///   name).
/// - `h_out`, `w_out`: output extents precomputed by the caller.
/// - `norm_p`: the exponent `p`, passed as `f32` for every dtype.
/// - `ceil_mode`: accepted for ABI uniformity but unused by the kernel.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_lp_pool_2d_f32_backward_run(
x: *const c_void, y: *const c_void, dy: *const c_void, dx: *mut c_void,
batch: i32, channels: i32, h_in: i32, w_in: i32,
kh: i32, kw: i32, sh: i32, sw: i32,
h_out: i32, w_out: i32,
norm_p: f32, ceil_mode: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_lp_pool_2d_f32_backward_run`.
///
/// Takes the `_run` arguments minus `stream`; `dx` is passed as
/// `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_lp_pool_2d_f32_backward_can_implement(
x: *const c_void, y: *const c_void, dy: *const c_void, dx: *const c_void,
batch: i32, channels: i32, h_in: i32, w_in: i32,
kh: i32, kw: i32, sh: i32, sw: i32,
h_out: i32, w_out: i32,
norm_p: f32, ceil_mode: i32,
) -> i32;
/// LpPool 2d BW, f64. Caller must zero `dx` before launch — the kernel
/// scatters into it with atomicAdd.
///
/// # Parameters
/// - `x`, `y`: read-only; presumably the forward input and output of
///   the matching FW call — confirm.
/// - `dy`: incoming gradient (read-only).
/// - `dx`: gradient buffer that is accumulated into (written).
/// - `batch`, `channels`, `h_in`, `w_in`: input extents in NCHW order.
/// - `kh`, `kw` / `sh`, `sw`: window size and stride along H / W (by
///   name).
/// - `h_out`, `w_out`: output extents precomputed by the caller.
/// - `norm_p`: the exponent `p`; stays `f32` even on this f64 path.
/// - `ceil_mode`: accepted for ABI uniformity but unused by the kernel.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_lp_pool_2d_f64_backward_run(
x: *const c_void, y: *const c_void, dy: *const c_void, dx: *mut c_void,
batch: i32, channels: i32, h_in: i32, w_in: i32,
kh: i32, kw: i32, sh: i32, sw: i32,
h_out: i32, w_out: i32,
norm_p: f32, ceil_mode: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_lp_pool_2d_f64_backward_run`.
///
/// Takes the `_run` arguments minus `stream`; `dx` is passed as
/// `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_lp_pool_2d_f64_backward_can_implement(
x: *const c_void, y: *const c_void, dy: *const c_void, dx: *const c_void,
batch: i32, channels: i32, h_in: i32, w_in: i32,
kh: i32, kw: i32, sh: i32, sw: i32,
h_out: i32, w_out: i32,
norm_p: f32, ceil_mode: i32,
) -> i32;
/// LpPool 2d BW, f16. Caller must zero `dx` before launch — the kernel
/// scatters into it with atomicAdd.
///
/// # Parameters
/// - `x`, `y`: read-only; presumably the forward input and output of
///   the matching FW call — confirm.
/// - `dy`: incoming gradient (read-only).
/// - `dx`: gradient buffer that is accumulated into (written).
/// - `batch`, `channels`, `h_in`, `w_in`: input extents in NCHW order.
/// - `kh`, `kw` / `sh`, `sw`: window size and stride along H / W (by
///   name).
/// - `h_out`, `w_out`: output extents precomputed by the caller.
/// - `norm_p`: the exponent `p`, passed as `f32` for every dtype.
/// - `ceil_mode`: accepted for ABI uniformity but unused by the kernel.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_lp_pool_2d_f16_backward_run(
x: *const c_void, y: *const c_void, dy: *const c_void, dx: *mut c_void,
batch: i32, channels: i32, h_in: i32, w_in: i32,
kh: i32, kw: i32, sh: i32, sw: i32,
h_out: i32, w_out: i32,
norm_p: f32, ceil_mode: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_lp_pool_2d_f16_backward_run`.
///
/// Takes the `_run` arguments minus `stream`; `dx` is passed as
/// `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_lp_pool_2d_f16_backward_can_implement(
x: *const c_void, y: *const c_void, dy: *const c_void, dx: *const c_void,
batch: i32, channels: i32, h_in: i32, w_in: i32,
kh: i32, kw: i32, sh: i32, sw: i32,
h_out: i32, w_out: i32,
norm_p: f32, ceil_mode: i32,
) -> i32;
/// LpPool 2d BW, bf16. Caller must zero `dx` before launch — the kernel
/// scatters into it with atomicAdd.
///
/// # Parameters
/// - `x`, `y`: read-only; presumably the forward input and output of
///   the matching FW call — confirm.
/// - `dy`: incoming gradient (read-only).
/// - `dx`: gradient buffer that is accumulated into (written).
/// - `batch`, `channels`, `h_in`, `w_in`: input extents in NCHW order.
/// - `kh`, `kw` / `sh`, `sw`: window size and stride along H / W (by
///   name).
/// - `h_out`, `w_out`: output extents precomputed by the caller.
/// - `norm_p`: the exponent `p`, passed as `f32` for every dtype.
/// - `ceil_mode`: accepted for ABI uniformity but unused by the kernel.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_lp_pool_2d_bf16_backward_run(
x: *const c_void, y: *const c_void, dy: *const c_void, dx: *mut c_void,
batch: i32, channels: i32, h_in: i32, w_in: i32,
kh: i32, kw: i32, sh: i32, sw: i32,
h_out: i32, w_out: i32,
norm_p: f32, ceil_mode: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_lp_pool_2d_bf16_backward_run`.
///
/// Takes the `_run` arguments minus `stream`; `dx` is passed as
/// `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_lp_pool_2d_bf16_backward_can_implement(
x: *const c_void, y: *const c_void, dy: *const c_void, dx: *const c_void,
batch: i32, channels: i32, h_in: i32, w_in: i32,
kh: i32, kw: i32, sh: i32, sw: i32,
h_out: i32, w_out: i32,
norm_p: f32, ceil_mode: i32,
) -> i32;
// ========================================================================
// Phase 16.3 — FractionalMaxPool 2-D + 3-D (FW + BW × 4 FP dtypes).
// ========================================================================
//
// Bespoke kernel; cuDNN has no fractional-pool primitive. Caller
// provides a `[N, C, num_axes]` f32 `random_samples` buffer (one α
// per (batch, channel, axis); see baracuda_fractional_max_pool.cuh
// for the "evenly-spaced base + α perturbation" window-placement
// formula and the note on PyTorch divergence).
//
// FW writes both `y` (per-window max, dtype `T`) and `indices`
// (per-window argmax linear index into `x`, dtype i64) — the
// saved-indices pattern shared with MaxPool BW.
//
// BW: one thread per output cell; reads `dy[i]` + `indices[i]`,
// atomicAdd's `dy[i]` into `dx[indices[i]]`. Caller must zero `dx`
// before launch. half / bf16 route through the atomicCAS helper
// (Phase 11.3 / Fuel feedback #6).
/// FractionalMaxPool2d FW, f32.
///
/// Writes the per-window max into `y` and the per-window argmax (linear
/// index into `x`, stored as i64) into `indices`.
///
/// # Parameters
/// - `x`: input (read-only).
/// - `y`: pooled output (written).
/// - `indices`: i64 argmax buffer (written).
/// - `random_samples`: caller-provided f32 buffer shaped
///   `[N, C, num_axes]`, one α per (batch, channel, axis); presumably
///   `num_axes = 2` for this 2-D variant — confirm.
/// - `batch`, `channels`: N and C extents.
/// - `h_in`, `w_in` / `h_out`, `w_out`: input / output spatial extents.
/// - `kh`, `kw`: pooling window height / width (by name).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_fractional_max_pool_2d_fw_f32_run(
x: *const c_void, y: *mut c_void,
indices: *mut c_void,
random_samples: *const f32,
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
kh: i32, kw: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_fractional_max_pool_2d_fw_f32_run`.
///
/// Takes the `_run` arguments minus `stream`; `y` and `indices` are
/// passed as `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_fractional_max_pool_2d_fw_f32_can_implement(
x: *const c_void, y: *const c_void,
indices: *const c_void,
random_samples: *const f32,
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
kh: i32, kw: i32,
) -> i32;
/// FractionalMaxPool2d FW, f64.
///
/// Writes the per-window max into `y` and the per-window argmax (linear
/// index into `x`, stored as i64) into `indices`.
///
/// # Parameters
/// - `x`: input (read-only).
/// - `y`: pooled output (written).
/// - `indices`: i64 argmax buffer (written).
/// - `random_samples`: caller-provided buffer shaped
///   `[N, C, num_axes]`, one α per (batch, channel, axis); typed
///   `*const f32` even for this f64 variant.
/// - `batch`, `channels`: N and C extents.
/// - `h_in`, `w_in` / `h_out`, `w_out`: input / output spatial extents.
/// - `kh`, `kw`: pooling window height / width (by name).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_fractional_max_pool_2d_fw_f64_run(
x: *const c_void, y: *mut c_void,
indices: *mut c_void,
random_samples: *const f32,
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
kh: i32, kw: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_fractional_max_pool_2d_fw_f64_run`.
///
/// Takes the `_run` arguments minus `stream`; `y` and `indices` are
/// passed as `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_fractional_max_pool_2d_fw_f64_can_implement(
x: *const c_void, y: *const c_void,
indices: *const c_void,
random_samples: *const f32,
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
kh: i32, kw: i32,
) -> i32;
/// FractionalMaxPool2d FW, f16.
///
/// Writes the per-window max into `y` and the per-window argmax (linear
/// index into `x`, stored as i64) into `indices`.
///
/// # Parameters
/// - `x`: input (read-only).
/// - `y`: pooled output (written).
/// - `indices`: i64 argmax buffer (written).
/// - `random_samples`: caller-provided buffer shaped
///   `[N, C, num_axes]`, one α per (batch, channel, axis); typed
///   `*const f32` even for this f16 variant.
/// - `batch`, `channels`: N and C extents.
/// - `h_in`, `w_in` / `h_out`, `w_out`: input / output spatial extents.
/// - `kh`, `kw`: pooling window height / width (by name).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_fractional_max_pool_2d_fw_f16_run(
x: *const c_void, y: *mut c_void,
indices: *mut c_void,
random_samples: *const f32,
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
kh: i32, kw: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_fractional_max_pool_2d_fw_f16_run`.
///
/// Takes the `_run` arguments minus `stream`; `y` and `indices` are
/// passed as `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_fractional_max_pool_2d_fw_f16_can_implement(
x: *const c_void, y: *const c_void,
indices: *const c_void,
random_samples: *const f32,
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
kh: i32, kw: i32,
) -> i32;
/// FractionalMaxPool2d FW, bf16.
///
/// Writes the per-window max into `y` and the per-window argmax (linear
/// index into `x`, stored as i64) into `indices`.
///
/// # Parameters
/// - `x`: input (read-only).
/// - `y`: pooled output (written).
/// - `indices`: i64 argmax buffer (written).
/// - `random_samples`: caller-provided buffer shaped
///   `[N, C, num_axes]`, one α per (batch, channel, axis); typed
///   `*const f32` even for this bf16 variant.
/// - `batch`, `channels`: N and C extents.
/// - `h_in`, `w_in` / `h_out`, `w_out`: input / output spatial extents.
/// - `kh`, `kw`: pooling window height / width (by name).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_fractional_max_pool_2d_fw_bf16_run(
x: *const c_void, y: *mut c_void,
indices: *mut c_void,
random_samples: *const f32,
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
kh: i32, kw: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_fractional_max_pool_2d_fw_bf16_run`.
///
/// Takes the `_run` arguments minus `stream`; `y` and `indices` are
/// passed as `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_fractional_max_pool_2d_fw_bf16_can_implement(
x: *const c_void, y: *const c_void,
indices: *const c_void,
random_samples: *const f32,
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
kh: i32, kw: i32,
) -> i32;
/// FractionalMaxPool2d BW, f32.
///
/// One thread per output cell reads `dy[i]` and `indices[i]` and
/// atomicAdd's `dy[i]` into `dx[indices[i]]`. Caller must zero `dx`
/// before launch.
///
/// # Parameters
/// - `dy`: incoming gradient, one value per output cell (read-only).
/// - `indices`: i64 argmax indices saved by the FW pass (read-only).
/// - `dx`: gradient buffer that is accumulated into (written).
/// - `batch`, `channels`: N and C extents.
/// - `h_in`, `w_in` / `h_out`, `w_out`: input / output spatial extents.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_fractional_max_pool_2d_bw_f32_run(
dy: *const c_void,
indices: *const c_void,
dx: *mut c_void,
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_fractional_max_pool_2d_bw_f32_run`.
///
/// Takes the `_run` arguments minus `stream`; `dx` is passed as
/// `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_fractional_max_pool_2d_bw_f32_can_implement(
dy: *const c_void,
indices: *const c_void,
dx: *const c_void,
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
) -> i32;
/// FractionalMaxPool2d BW, f64.
///
/// One thread per output cell reads `dy[i]` and `indices[i]` and
/// atomicAdd's `dy[i]` into `dx[indices[i]]`. Caller must zero `dx`
/// before launch.
///
/// # Parameters
/// - `dy`: incoming gradient, one value per output cell (read-only).
/// - `indices`: i64 argmax indices saved by the FW pass (read-only).
/// - `dx`: gradient buffer that is accumulated into (written).
/// - `batch`, `channels`: N and C extents.
/// - `h_in`, `w_in` / `h_out`, `w_out`: input / output spatial extents.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_fractional_max_pool_2d_bw_f64_run(
dy: *const c_void,
indices: *const c_void,
dx: *mut c_void,
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_fractional_max_pool_2d_bw_f64_run`.
///
/// Takes the `_run` arguments minus `stream`; `dx` is passed as
/// `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_fractional_max_pool_2d_bw_f64_can_implement(
dy: *const c_void,
indices: *const c_void,
dx: *const c_void,
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
) -> i32;
/// FractionalMaxPool2d BW, f16.
///
/// One thread per output cell reads `dy[i]` and `indices[i]` and
/// atomically adds `dy[i]` into `dx[indices[i]]`; for half / bf16 the
/// add routes through the atomicCAS helper. Caller must zero `dx`
/// before launch.
///
/// # Parameters
/// - `dy`: incoming gradient, one value per output cell (read-only).
/// - `indices`: i64 argmax indices saved by the FW pass (read-only).
/// - `dx`: gradient buffer that is accumulated into (written).
/// - `batch`, `channels`: N and C extents.
/// - `h_in`, `w_in` / `h_out`, `w_out`: input / output spatial extents.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_fractional_max_pool_2d_bw_f16_run(
dy: *const c_void,
indices: *const c_void,
dx: *mut c_void,
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_fractional_max_pool_2d_bw_f16_run`.
///
/// Takes the `_run` arguments minus `stream`; `dx` is passed as
/// `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_fractional_max_pool_2d_bw_f16_can_implement(
dy: *const c_void,
indices: *const c_void,
dx: *const c_void,
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
) -> i32;
/// FractionalMaxPool2d BW, bf16.
///
/// One thread per output cell reads `dy[i]` and `indices[i]` and
/// atomically adds `dy[i]` into `dx[indices[i]]`; for half / bf16 the
/// add routes through the atomicCAS helper. Caller must zero `dx`
/// before launch.
///
/// # Parameters
/// - `dy`: incoming gradient, one value per output cell (read-only).
/// - `indices`: i64 argmax indices saved by the FW pass (read-only).
/// - `dx`: gradient buffer that is accumulated into (written).
/// - `batch`, `channels`: N and C extents.
/// - `h_in`, `w_in` / `h_out`, `w_out`: input / output spatial extents.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_fractional_max_pool_2d_bw_bf16_run(
dy: *const c_void,
indices: *const c_void,
dx: *mut c_void,
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_fractional_max_pool_2d_bw_bf16_run`.
///
/// Takes the `_run` arguments minus `stream`; `dx` is passed as
/// `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_fractional_max_pool_2d_bw_bf16_can_implement(
dy: *const c_void,
indices: *const c_void,
dx: *const c_void,
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
) -> i32;
/// FractionalMaxPool3d FW, f32.
///
/// Writes the per-window max into `y` and the per-window argmax (linear
/// index into `x`, stored as i64) into `indices`.
///
/// # Parameters
/// - `x`: input (read-only).
/// - `y`: pooled output (written).
/// - `indices`: i64 argmax buffer (written).
/// - `random_samples`: caller-provided f32 buffer shaped
///   `[N, C, num_axes]`, one α per (batch, channel, axis); presumably
///   `num_axes = 3` for this 3-D variant — confirm.
/// - `batch`, `channels`: N and C extents.
/// - `d_in`, `h_in`, `w_in` / `d_out`, `h_out`, `w_out`: input / output
///   spatial extents.
/// - `kd`, `kh`, `kw`: pooling window depth / height / width (by name).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_fractional_max_pool_3d_fw_f32_run(
x: *const c_void, y: *mut c_void,
indices: *mut c_void,
random_samples: *const f32,
batch: i32, channels: i32,
d_in: i32, h_in: i32, w_in: i32,
d_out: i32, h_out: i32, w_out: i32,
kd: i32, kh: i32, kw: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_fractional_max_pool_3d_fw_f32_run`.
///
/// Takes the `_run` arguments minus `stream`; `y` and `indices` are
/// passed as `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_fractional_max_pool_3d_fw_f32_can_implement(
x: *const c_void, y: *const c_void,
indices: *const c_void,
random_samples: *const f32,
batch: i32, channels: i32,
d_in: i32, h_in: i32, w_in: i32,
d_out: i32, h_out: i32, w_out: i32,
kd: i32, kh: i32, kw: i32,
) -> i32;
/// FractionalMaxPool3d FW, f64.
///
/// Writes the per-window max into `y` and the per-window argmax (linear
/// index into `x`, stored as i64) into `indices`.
///
/// # Parameters
/// - `x`: input (read-only).
/// - `y`: pooled output (written).
/// - `indices`: i64 argmax buffer (written).
/// - `random_samples`: caller-provided buffer shaped
///   `[N, C, num_axes]`, one α per (batch, channel, axis); typed
///   `*const f32` even for this f64 variant.
/// - `batch`, `channels`: N and C extents.
/// - `d_in`, `h_in`, `w_in` / `d_out`, `h_out`, `w_out`: input / output
///   spatial extents.
/// - `kd`, `kh`, `kw`: pooling window depth / height / width (by name).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_fractional_max_pool_3d_fw_f64_run(
x: *const c_void, y: *mut c_void,
indices: *mut c_void,
random_samples: *const f32,
batch: i32, channels: i32,
d_in: i32, h_in: i32, w_in: i32,
d_out: i32, h_out: i32, w_out: i32,
kd: i32, kh: i32, kw: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_fractional_max_pool_3d_fw_f64_run`.
///
/// Takes the `_run` arguments minus `stream`; `y` and `indices` are
/// passed as `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_fractional_max_pool_3d_fw_f64_can_implement(
x: *const c_void, y: *const c_void,
indices: *const c_void,
random_samples: *const f32,
batch: i32, channels: i32,
d_in: i32, h_in: i32, w_in: i32,
d_out: i32, h_out: i32, w_out: i32,
kd: i32, kh: i32, kw: i32,
) -> i32;
/// FractionalMaxPool3d FW, f16.
///
/// Writes the per-window max into `y` and the per-window argmax (linear
/// index into `x`, stored as i64) into `indices`.
///
/// # Parameters
/// - `x`: input (read-only).
/// - `y`: pooled output (written).
/// - `indices`: i64 argmax buffer (written).
/// - `random_samples`: caller-provided buffer shaped
///   `[N, C, num_axes]`, one α per (batch, channel, axis); typed
///   `*const f32` even for this f16 variant.
/// - `batch`, `channels`: N and C extents.
/// - `d_in`, `h_in`, `w_in` / `d_out`, `h_out`, `w_out`: input / output
///   spatial extents.
/// - `kd`, `kh`, `kw`: pooling window depth / height / width (by name).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_fractional_max_pool_3d_fw_f16_run(
x: *const c_void, y: *mut c_void,
indices: *mut c_void,
random_samples: *const f32,
batch: i32, channels: i32,
d_in: i32, h_in: i32, w_in: i32,
d_out: i32, h_out: i32, w_out: i32,
kd: i32, kh: i32, kw: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_fractional_max_pool_3d_fw_f16_run`.
///
/// Takes the `_run` arguments minus `stream`; `y` and `indices` are
/// passed as `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_fractional_max_pool_3d_fw_f16_can_implement(
x: *const c_void, y: *const c_void,
indices: *const c_void,
random_samples: *const f32,
batch: i32, channels: i32,
d_in: i32, h_in: i32, w_in: i32,
d_out: i32, h_out: i32, w_out: i32,
kd: i32, kh: i32, kw: i32,
) -> i32;
/// FractionalMaxPool3d FW, bf16.
///
/// Writes the per-window max into `y` and the per-window argmax (linear
/// index into `x`, stored as i64) into `indices`.
///
/// # Parameters
/// - `x`: input (read-only).
/// - `y`: pooled output (written).
/// - `indices`: i64 argmax buffer (written).
/// - `random_samples`: caller-provided buffer shaped
///   `[N, C, num_axes]`, one α per (batch, channel, axis); typed
///   `*const f32` even for this bf16 variant.
/// - `batch`, `channels`: N and C extents.
/// - `d_in`, `h_in`, `w_in` / `d_out`, `h_out`, `w_out`: input / output
///   spatial extents.
/// - `kd`, `kh`, `kw`: pooling window depth / height / width (by name).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_fractional_max_pool_3d_fw_bf16_run(
x: *const c_void, y: *mut c_void,
indices: *mut c_void,
random_samples: *const f32,
batch: i32, channels: i32,
d_in: i32, h_in: i32, w_in: i32,
d_out: i32, h_out: i32, w_out: i32,
kd: i32, kh: i32, kw: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_fractional_max_pool_3d_fw_bf16_run`.
///
/// Takes the `_run` arguments minus `stream`; `y` and `indices` are
/// passed as `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_fractional_max_pool_3d_fw_bf16_can_implement(
x: *const c_void, y: *const c_void,
indices: *const c_void,
random_samples: *const f32,
batch: i32, channels: i32,
d_in: i32, h_in: i32, w_in: i32,
d_out: i32, h_out: i32, w_out: i32,
kd: i32, kh: i32, kw: i32,
) -> i32;
/// FractionalMaxPool3d BW, f32.
///
/// One thread per output cell reads `dy[i]` and `indices[i]` and
/// atomicAdd's `dy[i]` into `dx[indices[i]]`. Caller must zero `dx`
/// before launch.
///
/// # Parameters
/// - `dy`: incoming gradient, one value per output cell (read-only).
/// - `indices`: i64 argmax indices saved by the FW pass (read-only).
/// - `dx`: gradient buffer that is accumulated into (written).
/// - `batch`, `channels`: N and C extents.
/// - `d_in`, `h_in`, `w_in` / `d_out`, `h_out`, `w_out`: input / output
///   spatial extents.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_fractional_max_pool_3d_bw_f32_run(
dy: *const c_void,
indices: *const c_void,
dx: *mut c_void,
batch: i32, channels: i32,
d_in: i32, h_in: i32, w_in: i32,
d_out: i32, h_out: i32, w_out: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_fractional_max_pool_3d_bw_f32_run`.
///
/// Takes the `_run` arguments minus `stream`; `dx` is passed as
/// `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_fractional_max_pool_3d_bw_f32_can_implement(
dy: *const c_void,
indices: *const c_void,
dx: *const c_void,
batch: i32, channels: i32,
d_in: i32, h_in: i32, w_in: i32,
d_out: i32, h_out: i32, w_out: i32,
) -> i32;
/// FractionalMaxPool3d BW, f64.
///
/// One thread per output cell reads `dy[i]` and `indices[i]` and
/// atomicAdd's `dy[i]` into `dx[indices[i]]`. Caller must zero `dx`
/// before launch.
///
/// # Parameters
/// - `dy`: incoming gradient, one value per output cell (read-only).
/// - `indices`: i64 argmax indices saved by the FW pass (read-only).
/// - `dx`: gradient buffer that is accumulated into (written).
/// - `batch`, `channels`: N and C extents.
/// - `d_in`, `h_in`, `w_in` / `d_out`, `h_out`, `w_out`: input / output
///   spatial extents.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_fractional_max_pool_3d_bw_f64_run(
dy: *const c_void,
indices: *const c_void,
dx: *mut c_void,
batch: i32, channels: i32,
d_in: i32, h_in: i32, w_in: i32,
d_out: i32, h_out: i32, w_out: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_fractional_max_pool_3d_bw_f64_run`.
///
/// Takes the `_run` arguments minus `stream`; `dx` is passed as
/// `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_fractional_max_pool_3d_bw_f64_can_implement(
dy: *const c_void,
indices: *const c_void,
dx: *const c_void,
batch: i32, channels: i32,
d_in: i32, h_in: i32, w_in: i32,
d_out: i32, h_out: i32, w_out: i32,
) -> i32;
/// FractionalMaxPool3d BW, f16.
///
/// One thread per output cell reads `dy[i]` and `indices[i]` and
/// atomically adds `dy[i]` into `dx[indices[i]]`; for half / bf16 the
/// add routes through the atomicCAS helper. Caller must zero `dx`
/// before launch.
///
/// # Parameters
/// - `dy`: incoming gradient, one value per output cell (read-only).
/// - `indices`: i64 argmax indices saved by the FW pass (read-only).
/// - `dx`: gradient buffer that is accumulated into (written).
/// - `batch`, `channels`: N and C extents.
/// - `d_in`, `h_in`, `w_in` / `d_out`, `h_out`, `w_out`: input / output
///   spatial extents.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_fractional_max_pool_3d_bw_f16_run(
dy: *const c_void,
indices: *const c_void,
dx: *mut c_void,
batch: i32, channels: i32,
d_in: i32, h_in: i32, w_in: i32,
d_out: i32, h_out: i32, w_out: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_fractional_max_pool_3d_bw_f16_run`.
///
/// Takes the `_run` arguments minus `stream`; `dx` is passed as
/// `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_fractional_max_pool_3d_bw_f16_can_implement(
dy: *const c_void,
indices: *const c_void,
dx: *const c_void,
batch: i32, channels: i32,
d_in: i32, h_in: i32, w_in: i32,
d_out: i32, h_out: i32, w_out: i32,
) -> i32;
/// FractionalMaxPool3d BW, bf16.
///
/// One thread per output cell reads `dy[i]` and `indices[i]` and
/// atomically adds `dy[i]` into `dx[indices[i]]`; for half / bf16 the
/// add routes through the atomicCAS helper. Caller must zero `dx`
/// before launch.
///
/// # Parameters
/// - `dy`: incoming gradient, one value per output cell (read-only).
/// - `indices`: i64 argmax indices saved by the FW pass (read-only).
/// - `dx`: gradient buffer that is accumulated into (written).
/// - `batch`, `channels`: N and C extents.
/// - `d_in`, `h_in`, `w_in` / `d_out`, `h_out`, `w_out`: input / output
///   spatial extents.
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_fractional_max_pool_3d_bw_bf16_run(
dy: *const c_void,
indices: *const c_void,
dx: *mut c_void,
batch: i32, channels: i32,
d_in: i32, h_in: i32, w_in: i32,
d_out: i32, h_out: i32, w_out: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_fractional_max_pool_3d_bw_bf16_run`.
///
/// Takes the `_run` arguments minus `stream`; `dx` is passed as
/// `*const` here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_fractional_max_pool_3d_bw_bf16_can_implement(
dy: *const c_void,
indices: *const c_void,
dx: *const c_void,
batch: i32, channels: i32,
d_in: i32, h_in: i32, w_in: i32,
d_out: i32, h_out: i32, w_out: i32,
) -> i32;
// ========================================================================
// Phase 16.1 — bit-exact PyTorch adaptive pooling (Avg / Max,
// 1D / 2D / 3D, FW + BW × 4 FP dtypes).
// ========================================================================
//
// Rank-agnostic kernel template; the Rust side packs the 1D / 2D /
// 3D spatial shape into the `(in_d, in_h, in_w)` / `(out_d, out_h,
// out_w)` i32 args with degenerate `1`s filling unused leading
// axes (1D → in_d=1, in_h=1, in_w=L; 2D → in_d=1, in_h=H, in_w=W;
// 3D → in_d=D, in_h=H, in_w=W).
//
// `nc` is the outer batch×channels product (the kernels iterate
// each NC slice independently with PyTorch's non-uniform
// per-output-cell window formula). `spatial_rank ∈ {1, 2, 3}`.
//
// MaxPool FW writes an i64 argmax `indices` tensor (linear offset
// within each per-NC spatial slab). MaxPool BW consumes that
// saved-indices tensor and atomically adds `dy` into the saved
// positions.
//
// AvgPool BW + MaxPool BW both zero `dx` internally before the
// scatter — callers do NOT need to pre-zero. half / bf16 atomics
// route through `baracuda::atomic::add` (Phase 11.3 atomicCAS
// helper).
//
// Replaces the Phase 11.8 cuDNN-approximation path (uniform
// `kernel = ceil(in/out)` / `stride = floor(in/out)`).
/// Adaptive AvgPool FW, f16. Rank-agnostic (`spatial_rank ∈ {1,2,3}`).
///
/// # Parameters
/// - `x`: input (read-only); `y`: output (written).
/// - `nc`: outer batch×channels product; each NC slice is processed
///   independently.
/// - `spatial_rank`: number of spatial axes actually in use.
/// - `in_d`, `in_h`, `in_w` / `out_d`, `out_h`, `out_w`: input / output
///   spatial extents, with unused leading axes set to `1`
///   (1D → `in_d = 1, in_h = 1, in_w = L`; 2D → `in_d = 1`).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_adaptive_avg_pool_f16_fw_run(
x: *const c_void, y: *mut c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
stream: *mut c_void,
) -> i32;
/// Implementability check for
/// `baracuda_kernels_adaptive_avg_pool_f16_fw_run`.
///
/// Takes the `_run` arguments minus `stream`; `y` is passed as `*const`
/// here.
///
/// Returns the crate-level `i32` status code (`0` = success; non-zero
/// values are listed under "Status codes" in the crate docs).
pub fn baracuda_kernels_adaptive_avg_pool_f16_fw_can_implement(
x: *const c_void, y: *const c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
) -> i32;
/// Adaptive AvgPool FW, bf16. Rank-agnostic (`spatial_rank ∈ {1,2,3}`).
///
/// # Parameters
/// - `x`: input (read-only); `y`: output (written).
/// - `nc`: outer batch×channels product; each NC slice is processed
///   independently.
/// - `spatial_rank`: number of spatial axes actually in use.
/// - `in_d`, `in_h`, `in_w` / `out_d`, `out_h`, `out_w`: input / output
///   spatial extents, with unused leading axes set to `1`
///   (1D → `in_d = 1, in_h = 1, in_w = L`; 2D → `in_d = 1`).
/// - `stream`: `cudaStream_t` cast to `*mut c_void`.
///
/// Returns the crate-level `i32` status code (`0` = success).
pub fn baracuda_kernels_adaptive_avg_pool_bf16_fw_run(
x: *const c_void, y: *mut c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the bf16 adaptive AvgPool FW kernel.
///
/// Takes the arguments of
/// [`baracuda_kernels_adaptive_avg_pool_bf16_fw_run`] minus `stream` and
/// returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_avg_pool_bf16_fw_can_implement(
x: *const c_void, y: *const c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
) -> i32;
/// Adaptive AvgPool FW, f32. Rank-agnostic (`spatial_rank ∈ {1,2,3}`).
///
/// Reads `x` and writes `y`. `nc` is the batch×channels product; unused
/// leading spatial axes are passed as `1`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_adaptive_avg_pool_f32_fw_run(
x: *const c_void, y: *mut c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f32 adaptive AvgPool FW kernel.
///
/// Takes the arguments of
/// [`baracuda_kernels_adaptive_avg_pool_f32_fw_run`] minus `stream` and
/// returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_avg_pool_f32_fw_can_implement(
x: *const c_void, y: *const c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
) -> i32;
/// Adaptive AvgPool FW, f64. Rank-agnostic (`spatial_rank ∈ {1,2,3}`).
///
/// Reads `x` and writes `y`. `nc` is the batch×channels product; unused
/// leading spatial axes are passed as `1`. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_adaptive_avg_pool_f64_fw_run(
x: *const c_void, y: *mut c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f64 adaptive AvgPool FW kernel.
///
/// Takes the arguments of
/// [`baracuda_kernels_adaptive_avg_pool_f64_fw_run`] minus `stream` and
/// returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_avg_pool_f64_fw_can_implement(
x: *const c_void, y: *const c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
) -> i32;
/// Adaptive AvgPool BW, f16. Zeros `dx` internally, then atomic-scatters.
///
/// Reads the upstream gradient `dy` and writes `dx`; callers do not
/// need to pre-zero `dx`. Shape arguments follow the FW convention.
/// Returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_avg_pool_f16_bw_run(
dy: *const c_void, dx: *mut c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f16 adaptive AvgPool BW kernel.
///
/// Takes the arguments of
/// [`baracuda_kernels_adaptive_avg_pool_f16_bw_run`] minus `stream` and
/// returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_avg_pool_f16_bw_can_implement(
dy: *const c_void, dx: *const c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
) -> i32;
/// Adaptive AvgPool BW, bf16.
///
/// Reads the upstream gradient `dy` and writes `dx`, which is zeroed
/// internally before the scatter (callers do not pre-zero). Returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_avg_pool_bf16_bw_run(
dy: *const c_void, dx: *mut c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the bf16 adaptive AvgPool BW kernel.
///
/// Takes the arguments of
/// [`baracuda_kernels_adaptive_avg_pool_bf16_bw_run`] minus `stream` and
/// returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_avg_pool_bf16_bw_can_implement(
dy: *const c_void, dx: *const c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
) -> i32;
/// Adaptive AvgPool BW, f32.
///
/// Reads the upstream gradient `dy` and writes `dx`, which is zeroed
/// internally before the scatter (callers do not pre-zero). Returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_avg_pool_f32_bw_run(
dy: *const c_void, dx: *mut c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f32 adaptive AvgPool BW kernel.
///
/// Takes the arguments of
/// [`baracuda_kernels_adaptive_avg_pool_f32_bw_run`] minus `stream` and
/// returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_avg_pool_f32_bw_can_implement(
dy: *const c_void, dx: *const c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
) -> i32;
/// Adaptive AvgPool BW, f64.
///
/// Reads the upstream gradient `dy` and writes `dx`, which is zeroed
/// internally before the scatter (callers do not pre-zero). Returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_avg_pool_f64_bw_run(
dy: *const c_void, dx: *mut c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f64 adaptive AvgPool BW kernel.
///
/// Takes the arguments of
/// [`baracuda_kernels_adaptive_avg_pool_f64_bw_run`] minus `stream` and
/// returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_avg_pool_f64_bw_can_implement(
dy: *const c_void, dx: *const c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
) -> i32;
/// Adaptive MaxPool FW, f16. Writes `y` only — the matching BW
/// recomputes the argmax internally from the saved `x` (keeps the
/// Phase 11.8 args shape; no separate indices tensor).
///
/// `nc` is the batch×channels product; unused leading spatial axes are
/// passed as `1`. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_max_pool_f16_fw_run(
x: *const c_void, y: *mut c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f16 adaptive MaxPool FW kernel.
///
/// Takes the arguments of
/// [`baracuda_kernels_adaptive_max_pool_f16_fw_run`] minus `stream` and
/// returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_max_pool_f16_fw_can_implement(
x: *const c_void, y: *const c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
) -> i32;
/// Adaptive MaxPool FW, bf16.
///
/// Reads `x` and writes `y` only (the signature carries no indices
/// tensor). `nc` is the batch×channels product; unused leading spatial
/// axes are passed as `1`. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_max_pool_bf16_fw_run(
x: *const c_void, y: *mut c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the bf16 adaptive MaxPool FW kernel.
///
/// Takes the arguments of
/// [`baracuda_kernels_adaptive_max_pool_bf16_fw_run`] minus `stream` and
/// returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_max_pool_bf16_fw_can_implement(
x: *const c_void, y: *const c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
) -> i32;
/// Adaptive MaxPool FW, f32.
///
/// Reads `x` and writes `y` only (the signature carries no indices
/// tensor). `nc` is the batch×channels product; unused leading spatial
/// axes are passed as `1`. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_max_pool_f32_fw_run(
x: *const c_void, y: *mut c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f32 adaptive MaxPool FW kernel.
///
/// Takes the arguments of
/// [`baracuda_kernels_adaptive_max_pool_f32_fw_run`] minus `stream` and
/// returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_max_pool_f32_fw_can_implement(
x: *const c_void, y: *const c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
) -> i32;
/// Adaptive MaxPool FW, f64.
///
/// Reads `x` and writes `y` only (the signature carries no indices
/// tensor). `nc` is the batch×channels product; unused leading spatial
/// axes are passed as `1`. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_max_pool_f64_fw_run(
x: *const c_void, y: *mut c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f64 adaptive MaxPool FW kernel.
///
/// Takes the arguments of
/// [`baracuda_kernels_adaptive_max_pool_f64_fw_run`] minus `stream` and
/// returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_max_pool_f64_fw_can_implement(
x: *const c_void, y: *const c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
) -> i32;
/// Adaptive MaxPool BW, f16. Recomputes the per-window argmax from
/// the saved `x`, zeros `dx` internally, then atomic-scatters `dy`
/// into the argmax positions.
///
/// Callers do not need to pre-zero `dx`. Shape arguments follow the FW
/// convention. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_max_pool_f16_bw_run(
x: *const c_void, dy: *const c_void, dx: *mut c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f16 adaptive MaxPool BW kernel.
///
/// Takes the arguments of
/// [`baracuda_kernels_adaptive_max_pool_f16_bw_run`] minus `stream` and
/// returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_max_pool_f16_bw_can_implement(
x: *const c_void, dy: *const c_void, dx: *const c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
) -> i32;
/// Adaptive MaxPool BW, bf16.
///
/// Takes the saved forward input `x` and the upstream gradient `dy`;
/// writes `dx`, which is zeroed internally before the scatter (callers
/// do not pre-zero). Returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_max_pool_bf16_bw_run(
x: *const c_void, dy: *const c_void, dx: *mut c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the bf16 adaptive MaxPool BW kernel.
///
/// Takes the arguments of
/// [`baracuda_kernels_adaptive_max_pool_bf16_bw_run`] minus `stream` and
/// returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_max_pool_bf16_bw_can_implement(
x: *const c_void, dy: *const c_void, dx: *const c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
) -> i32;
/// Adaptive MaxPool BW, f32.
///
/// Takes the saved forward input `x` and the upstream gradient `dy`;
/// writes `dx`, which is zeroed internally before the scatter (callers
/// do not pre-zero). Returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_max_pool_f32_bw_run(
x: *const c_void, dy: *const c_void, dx: *mut c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f32 adaptive MaxPool BW kernel.
///
/// Takes the arguments of
/// [`baracuda_kernels_adaptive_max_pool_f32_bw_run`] minus `stream` and
/// returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_max_pool_f32_bw_can_implement(
x: *const c_void, dy: *const c_void, dx: *const c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
) -> i32;
/// Adaptive MaxPool BW, f64.
///
/// Takes the saved forward input `x` and the upstream gradient `dy`;
/// writes `dx`, which is zeroed internally before the scatter (callers
/// do not pre-zero). Returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_max_pool_f64_bw_run(
x: *const c_void, dy: *const c_void, dx: *mut c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f64 adaptive MaxPool BW kernel.
///
/// Takes the arguments of
/// [`baracuda_kernels_adaptive_max_pool_f64_bw_run`] minus `stream` and
/// returns a crate status code (`0` = success).
pub fn baracuda_kernels_adaptive_max_pool_f64_bw_can_implement(
x: *const c_void, dy: *const c_void, dx: *const c_void,
nc: i32, spatial_rank: i32,
in_d: i32, in_h: i32, in_w: i32,
out_d: i32, out_h: i32, out_w: i32,
) -> i32;
// ========================================================================
// Phase 19.3 — im2col / im2col1d / col2im1d bespoke kernels.
// ========================================================================
//
// Building blocks for Fuel's conv-via-im2col-and-GEMM fallback
// lowering + the conv-backward filter-gradient path. Three ops ×
// four FP dtypes (f16, bf16, f32, f64) = 12 kernels, each exposed as a
// `_run` + `_can_implement` pair (24 FFI symbols).
//
// im2col_2d: NCHW input `[N, C, H_in, W_in]` →
// col output `[N, C·kh·kw, h_out·w_out]`.
// im2col_1d: NCL input `[N, C, L_in]` →
// col output `[N, C·kl, l_out]`.
// col2im_1d: col input `[N, C·kl, l_out]` →
// NCL output `[N, C, L_in]` — inverse of im2col_1d.
//
// Output extents follow the standard PyTorch / cuDNN formula:
// h_out = (H_in + 2·pad_h - dilation_h·(kh-1) - 1) / stride_h + 1
// w_out, l_out: analogous.
//
// col2im_1d uses atomicAdd scatter for overlapping window cells
// (when stride < kernel). half/bf16 route through the 32-bit CAS
// path in `baracuda::atomic::add<T>` for universal availability.
// **Caller must pre-zero the output buffer before col2im_1d.**
/// im2col 2-D, f32.
///
/// NCHW `input` `[N, C, H_in, W_in]` → col `output`
/// `[N, C·kh·kw, h_out·w_out]`. `h_out` / `w_out` are supplied by the
/// caller. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_im2col_2d_f32_run(
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
kh: i32, kw: i32,
stride_h: i32, stride_w: i32,
pad_h: i32, pad_w: i32,
dilation_h: i32, dilation_w: i32,
input: *const c_void,
output: *mut c_void,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f32 im2col 2-D kernel.
///
/// Takes the arguments of [`baracuda_kernels_im2col_2d_f32_run`] minus
/// `stream` and returns a crate status code (`0` = success).
pub fn baracuda_kernels_im2col_2d_f32_can_implement(
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
kh: i32, kw: i32,
stride_h: i32, stride_w: i32,
pad_h: i32, pad_w: i32,
dilation_h: i32, dilation_w: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// im2col 2-D, f64.
///
/// NCHW `input` `[N, C, H_in, W_in]` → col `output`
/// `[N, C·kh·kw, h_out·w_out]`. `h_out` / `w_out` are supplied by the
/// caller. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_im2col_2d_f64_run(
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
kh: i32, kw: i32,
stride_h: i32, stride_w: i32,
pad_h: i32, pad_w: i32,
dilation_h: i32, dilation_w: i32,
input: *const c_void,
output: *mut c_void,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f64 im2col 2-D kernel.
///
/// Takes the arguments of [`baracuda_kernels_im2col_2d_f64_run`] minus
/// `stream` and returns a crate status code (`0` = success).
pub fn baracuda_kernels_im2col_2d_f64_can_implement(
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
kh: i32, kw: i32,
stride_h: i32, stride_w: i32,
pad_h: i32, pad_w: i32,
dilation_h: i32, dilation_w: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// im2col 2-D, f16.
///
/// NCHW `input` `[N, C, H_in, W_in]` → col `output`
/// `[N, C·kh·kw, h_out·w_out]`. `h_out` / `w_out` are supplied by the
/// caller. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_im2col_2d_f16_run(
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
kh: i32, kw: i32,
stride_h: i32, stride_w: i32,
pad_h: i32, pad_w: i32,
dilation_h: i32, dilation_w: i32,
input: *const c_void,
output: *mut c_void,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f16 im2col 2-D kernel.
///
/// Takes the arguments of [`baracuda_kernels_im2col_2d_f16_run`] minus
/// `stream` and returns a crate status code (`0` = success).
pub fn baracuda_kernels_im2col_2d_f16_can_implement(
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
kh: i32, kw: i32,
stride_h: i32, stride_w: i32,
pad_h: i32, pad_w: i32,
dilation_h: i32, dilation_w: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// im2col 2-D, bf16.
///
/// NCHW `input` `[N, C, H_in, W_in]` → col `output`
/// `[N, C·kh·kw, h_out·w_out]`. `h_out` / `w_out` are supplied by the
/// caller. Returns a crate status code (`0` = success).
pub fn baracuda_kernels_im2col_2d_bf16_run(
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
kh: i32, kw: i32,
stride_h: i32, stride_w: i32,
pad_h: i32, pad_w: i32,
dilation_h: i32, dilation_w: i32,
input: *const c_void,
output: *mut c_void,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the bf16 im2col 2-D kernel.
///
/// Takes the arguments of [`baracuda_kernels_im2col_2d_bf16_run`] minus
/// `stream` and returns a crate status code (`0` = success).
pub fn baracuda_kernels_im2col_2d_bf16_can_implement(
batch: i32, channels: i32,
h_in: i32, w_in: i32,
h_out: i32, w_out: i32,
kh: i32, kw: i32,
stride_h: i32, stride_w: i32,
pad_h: i32, pad_w: i32,
dilation_h: i32, dilation_w: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// im2col 1-D, f32.
///
/// NCL `input` `[N, C, L_in]` → col `output` `[N, C·kl, l_out]`;
/// `l_out` is supplied by the caller. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_im2col_1d_f32_run(
batch: i32, channels: i32,
l_in: i32, l_out: i32,
kl: i32, stride_l: i32, pad_l: i32, dilation_l: i32,
input: *const c_void,
output: *mut c_void,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f32 im2col 1-D kernel.
///
/// Takes the arguments of [`baracuda_kernels_im2col_1d_f32_run`] minus
/// `stream` and returns a crate status code (`0` = success).
pub fn baracuda_kernels_im2col_1d_f32_can_implement(
batch: i32, channels: i32,
l_in: i32, l_out: i32,
kl: i32, stride_l: i32, pad_l: i32, dilation_l: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// im2col 1-D, f64.
///
/// NCL `input` `[N, C, L_in]` → col `output` `[N, C·kl, l_out]`;
/// `l_out` is supplied by the caller. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_im2col_1d_f64_run(
batch: i32, channels: i32,
l_in: i32, l_out: i32,
kl: i32, stride_l: i32, pad_l: i32, dilation_l: i32,
input: *const c_void,
output: *mut c_void,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f64 im2col 1-D kernel.
///
/// Takes the arguments of [`baracuda_kernels_im2col_1d_f64_run`] minus
/// `stream` and returns a crate status code (`0` = success).
pub fn baracuda_kernels_im2col_1d_f64_can_implement(
batch: i32, channels: i32,
l_in: i32, l_out: i32,
kl: i32, stride_l: i32, pad_l: i32, dilation_l: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// im2col 1-D, f16.
///
/// NCL `input` `[N, C, L_in]` → col `output` `[N, C·kl, l_out]`;
/// `l_out` is supplied by the caller. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_im2col_1d_f16_run(
batch: i32, channels: i32,
l_in: i32, l_out: i32,
kl: i32, stride_l: i32, pad_l: i32, dilation_l: i32,
input: *const c_void,
output: *mut c_void,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f16 im2col 1-D kernel.
///
/// Takes the arguments of [`baracuda_kernels_im2col_1d_f16_run`] minus
/// `stream` and returns a crate status code (`0` = success).
pub fn baracuda_kernels_im2col_1d_f16_can_implement(
batch: i32, channels: i32,
l_in: i32, l_out: i32,
kl: i32, stride_l: i32, pad_l: i32, dilation_l: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// im2col 1-D, bf16.
///
/// NCL `input` `[N, C, L_in]` → col `output` `[N, C·kl, l_out]`;
/// `l_out` is supplied by the caller. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_im2col_1d_bf16_run(
batch: i32, channels: i32,
l_in: i32, l_out: i32,
kl: i32, stride_l: i32, pad_l: i32, dilation_l: i32,
input: *const c_void,
output: *mut c_void,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the bf16 im2col 1-D kernel.
///
/// Takes the arguments of [`baracuda_kernels_im2col_1d_bf16_run`] minus
/// `stream` and returns a crate status code (`0` = success).
pub fn baracuda_kernels_im2col_1d_bf16_can_implement(
batch: i32, channels: i32,
l_in: i32, l_out: i32,
kl: i32, stride_l: i32, pad_l: i32, dilation_l: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// col2im 1-D, f32. Caller must zero `output` first.
///
/// Col `input` `[N, C·kl, l_out]` → NCL `output` `[N, C, L_in]`
/// (inverse of im2col 1-D). Overlapping window cells are accumulated
/// with an atomicAdd scatter, hence the pre-zero requirement. Returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_col2im_1d_f32_run(
batch: i32, channels: i32,
l_in: i32, l_out: i32,
kl: i32, stride_l: i32, pad_l: i32, dilation_l: i32,
input: *const c_void,
output: *mut c_void,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f32 col2im 1-D kernel.
///
/// Takes the arguments of [`baracuda_kernels_col2im_1d_f32_run`] minus
/// `stream` and returns a crate status code (`0` = success).
pub fn baracuda_kernels_col2im_1d_f32_can_implement(
batch: i32, channels: i32,
l_in: i32, l_out: i32,
kl: i32, stride_l: i32, pad_l: i32, dilation_l: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// col2im 1-D, f64. Caller must zero `output` first.
///
/// Col `input` `[N, C·kl, l_out]` → NCL `output` `[N, C, L_in]`
/// (inverse of im2col 1-D). Overlapping window cells are accumulated
/// with an atomicAdd scatter, hence the pre-zero requirement. Returns a
/// crate status code (`0` = success).
pub fn baracuda_kernels_col2im_1d_f64_run(
batch: i32, channels: i32,
l_in: i32, l_out: i32,
kl: i32, stride_l: i32, pad_l: i32, dilation_l: i32,
input: *const c_void,
output: *mut c_void,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f64 col2im 1-D kernel.
///
/// Takes the arguments of [`baracuda_kernels_col2im_1d_f64_run`] minus
/// `stream` and returns a crate status code (`0` = success).
pub fn baracuda_kernels_col2im_1d_f64_can_implement(
batch: i32, channels: i32,
l_in: i32, l_out: i32,
kl: i32, stride_l: i32, pad_l: i32, dilation_l: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// col2im 1-D, f16. Caller must zero `output` first.
///
/// Col `input` `[N, C·kl, l_out]` → NCL `output` `[N, C, L_in]`
/// (inverse of im2col 1-D). Overlapping window cells are accumulated
/// with an atomic scatter (half routes through the 32-bit CAS path),
/// hence the pre-zero requirement. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_col2im_1d_f16_run(
batch: i32, channels: i32,
l_in: i32, l_out: i32,
kl: i32, stride_l: i32, pad_l: i32, dilation_l: i32,
input: *const c_void,
output: *mut c_void,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f16 col2im 1-D kernel.
///
/// Takes the arguments of [`baracuda_kernels_col2im_1d_f16_run`] minus
/// `stream` and returns a crate status code (`0` = success).
pub fn baracuda_kernels_col2im_1d_f16_can_implement(
batch: i32, channels: i32,
l_in: i32, l_out: i32,
kl: i32, stride_l: i32, pad_l: i32, dilation_l: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// col2im 1-D, bf16. Caller must zero `output` first.
///
/// Col `input` `[N, C·kl, l_out]` → NCL `output` `[N, C, L_in]`
/// (inverse of im2col 1-D). Overlapping window cells are accumulated
/// with an atomic scatter (bf16 routes through the 32-bit CAS path),
/// hence the pre-zero requirement. Returns a crate status code
/// (`0` = success).
pub fn baracuda_kernels_col2im_1d_bf16_run(
batch: i32, channels: i32,
l_in: i32, l_out: i32,
kl: i32, stride_l: i32, pad_l: i32, dilation_l: i32,
input: *const c_void,
output: *mut c_void,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the bf16 col2im 1-D kernel.
///
/// Takes the arguments of [`baracuda_kernels_col2im_1d_bf16_run`] minus
/// `stream` and returns a crate status code (`0` = success).
pub fn baracuda_kernels_col2im_1d_bf16_can_implement(
batch: i32, channels: i32,
l_in: i32, l_out: i32,
kl: i32, stride_l: i32, pad_l: i32, dilation_l: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
}
// ============================================================================
// Phase 19.2 — upsample (nearest 2D) FFI surface
// ============================================================================
//
// Adds the missing `nearest-2D` mode (FW + BW × 4 fp dtypes = 8 kernels,
// each exposed as a `_run` + `_can_implement` pair, 16 symbols)
// under the new `upsample_*` namespace specified by Phase 19.2's
// design-correction plan. The existing `interpolate_bilinear_2d_*`
// symbols (kernels/image/interpolate.cu) are re-exported under the
// `upsample_bilinear_2d_*` namespace via the alias block immediately
// below — same machine code, two symbol names. Bilinear alias
// coverage is f32, f64, f16, and bf16 (extended from f32 + f64 in
// Phase 21; see the alias section's own header).
//
// BW kernels scatter via `atomicAdd` — **caller must pre-zero
// `dinput`** before launch (matches the bilinear BW contract).
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---- upsample_nearest_2d FW (4 fp dtypes) ----
/// `upsample(x, mode='nearest')` FW, f32.
///
/// `input`: `[N, C, IH, IW]`; `output`: `[N, C, OH, OW]`. NCHW.
/// Coordinate mapping: nearest under `align_corners=false`.
/// `workspace`, when non-null, must point to at least
/// `workspace_bytes` of writable device memory. Returns a crate status
/// code (`0` = success).
///
/// # Safety
/// All pointers must be live device memory; `stream` valid.
pub fn baracuda_kernels_upsample_nearest_2d_fw_f32_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f32 nearest-2D upsample FW kernel.
///
/// Takes the shape and the `input` / `output` pointers of the matching
/// `_run` entry point (no workspace, no stream) and returns a crate
/// status code (`0` = success).
pub fn baracuda_kernels_upsample_nearest_2d_fw_f32_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// `upsample_nearest_2d` FW, f64. Same layout and arguments as the f32
/// FW entry point.
///
/// # Safety
/// As f32.
pub fn baracuda_kernels_upsample_nearest_2d_fw_f64_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f64 nearest-2D upsample FW kernel.
///
/// Takes the shape and the `input` / `output` pointers of the matching
/// `_run` entry point (no workspace, no stream) and returns a crate
/// status code (`0` = success).
pub fn baracuda_kernels_upsample_nearest_2d_fw_f64_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// `upsample_nearest_2d` FW, f16. Same layout and arguments as the f32
/// FW entry point.
///
/// # Safety
/// As f32.
pub fn baracuda_kernels_upsample_nearest_2d_fw_f16_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f16 nearest-2D upsample FW kernel.
///
/// Takes the shape and the `input` / `output` pointers of the matching
/// `_run` entry point (no workspace, no stream) and returns a crate
/// status code (`0` = success).
pub fn baracuda_kernels_upsample_nearest_2d_fw_f16_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
/// `upsample_nearest_2d` FW, bf16. Same layout and arguments as the
/// f32 FW entry point.
///
/// # Safety
/// As f32.
pub fn baracuda_kernels_upsample_nearest_2d_fw_bf16_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
output: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the bf16 nearest-2D upsample FW kernel.
///
/// Takes the shape and the `input` / `output` pointers of the matching
/// `_run` entry point (no workspace, no stream) and returns a crate
/// status code (`0` = success).
pub fn baracuda_kernels_upsample_nearest_2d_fw_bf16_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
input: *const c_void,
output: *const c_void,
) -> i32;
// ---- upsample_nearest_2d BW (4 fp dtypes; caller pre-zeros dinput) ----
/// `upsample_nearest_2d` BW, f32. Caller pre-zeros `dinput`.
///
/// Reads the upstream gradient `dout` and accumulates into `dinput`.
/// Returns a crate status code (`0` = success).
///
/// # Safety
/// As FW.
pub fn baracuda_kernels_upsample_nearest_2d_bw_f32_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f32 nearest-2D upsample BW kernel.
///
/// Takes the shape and the `dout` / `dinput` pointers of the matching
/// `_run` entry point (no workspace, no stream) and returns a crate
/// status code (`0` = success).
pub fn baracuda_kernels_upsample_nearest_2d_bw_f32_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
dinput: *const c_void,
) -> i32;
/// `upsample_nearest_2d` BW, f64. Caller pre-zeros `dinput`.
///
/// # Safety
/// As f32 BW.
pub fn baracuda_kernels_upsample_nearest_2d_bw_f64_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f64 nearest-2D upsample BW kernel.
///
/// Takes the shape and the `dout` / `dinput` pointers of the matching
/// `_run` entry point (no workspace, no stream) and returns a crate
/// status code (`0` = success).
pub fn baracuda_kernels_upsample_nearest_2d_bw_f64_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
dinput: *const c_void,
) -> i32;
/// `upsample_nearest_2d` BW, f16. Caller pre-zeros `dinput`. Uses the
/// `baracuda::atomic::add<__half>` (CAS-based) helper.
///
/// # Safety
/// As f32 BW.
pub fn baracuda_kernels_upsample_nearest_2d_bw_f16_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the f16 nearest-2D upsample BW kernel.
///
/// Takes the shape and the `dout` / `dinput` pointers of the matching
/// `_run` entry point (no workspace, no stream) and returns a crate
/// status code (`0` = success).
pub fn baracuda_kernels_upsample_nearest_2d_bw_f16_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
dinput: *const c_void,
) -> i32;
/// `upsample_nearest_2d` BW, bf16. Caller pre-zeros `dinput`.
///
/// # Safety
/// As f32 BW.
pub fn baracuda_kernels_upsample_nearest_2d_bw_bf16_run(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
dinput: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `can_implement` query for the bf16 nearest-2D upsample BW kernel.
///
/// Takes the shape and the `dout` / `dinput` pointers of the matching
/// `_run` entry point (no workspace, no stream) and returns a crate
/// status code (`0` = success).
pub fn baracuda_kernels_upsample_nearest_2d_bw_bf16_can_implement(
N: i32, C: i32, IH: i32, IW: i32, OH: i32, OW: i32,
dout: *const c_void,
dinput: *const c_void,
) -> i32;
}
// ============================================================================
// Phase 19.2 — upsample (bilinear 2D) aliases
// ============================================================================
//
// Thin Rust-side aliases mapping the new `upsample_bilinear_2d_*`
// naming convention to the existing `interpolate_bilinear_2d_*`
// symbols. Same C-ABI, same machine code; the two namespaces co-exist
// so callers on either side of the rename keep working.
//
// Phase 21: dtype coverage extended to {f32, f64, f16, bf16} and the
// signature gained `align_corners` + `scale_h_factor` +
// `scale_w_factor` params (forwarded unchanged to the underlying
// `interpolate_*` symbol).
//
// These are `pub fn` Rust wrappers (`#[inline]`) rather than alias
// symbols to keep `baracuda-kernels-sys` self-contained — downstream
// consumers can `use baracuda_kernels_sys::baracuda_kernels_upsample_bilinear_2d_fw_f32_run`
// uniformly with the rest of the upsample family.
/// Bilinear 2-D upsample forward pass, f32 — the Phase 19.2
/// `upsample_*` name for
/// [`baracuda_kernels_interpolate_bilinear_2d_f32_run`].
///
/// Every argument is forwarded unchanged and the callee's status code
/// is returned as-is.
///
/// # Safety
/// Identical contract to
/// [`baracuda_kernels_interpolate_bilinear_2d_f32_run`].
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
#[inline]
pub unsafe fn baracuda_kernels_upsample_bilinear_2d_fw_f32_run(
    n: i32,
    c: i32,
    ih: i32,
    iw: i32,
    oh: i32,
    ow: i32,
    input: *const c_void,
    output: *mut c_void,
    workspace: *mut c_void,
    workspace_bytes: usize,
    align_corners: i32,
    scale_h_factor: f64,
    scale_w_factor: f64,
    stream: *mut c_void,
) -> i32 {
    // SAFETY: pure pass-through; the caller upholds the callee's contract.
    unsafe {
        baracuda_kernels_interpolate_bilinear_2d_f32_run(
            n, c, ih, iw, oh, ow,
            input, output,
            workspace, workspace_bytes,
            align_corners, scale_h_factor, scale_w_factor,
            stream,
        )
    }
}
/// Bilinear 2-D upsample forward pass, f64 — `upsample_*` name for
/// [`baracuda_kernels_interpolate_bilinear_2d_f64_run`].
///
/// Every argument is forwarded unchanged and the callee's status code
/// is returned as-is.
///
/// # Safety
/// Identical contract to
/// [`baracuda_kernels_interpolate_bilinear_2d_f64_run`].
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
#[inline]
pub unsafe fn baracuda_kernels_upsample_bilinear_2d_fw_f64_run(
    n: i32,
    c: i32,
    ih: i32,
    iw: i32,
    oh: i32,
    ow: i32,
    input: *const c_void,
    output: *mut c_void,
    workspace: *mut c_void,
    workspace_bytes: usize,
    align_corners: i32,
    scale_h_factor: f64,
    scale_w_factor: f64,
    stream: *mut c_void,
) -> i32 {
    // SAFETY: pure pass-through; the caller upholds the callee's contract.
    unsafe {
        baracuda_kernels_interpolate_bilinear_2d_f64_run(
            n, c, ih, iw, oh, ow,
            input, output,
            workspace, workspace_bytes,
            align_corners, scale_h_factor, scale_w_factor,
            stream,
        )
    }
}
/// Bilinear 2-D upsample forward pass, f16 — `upsample_*` name for
/// [`baracuda_kernels_interpolate_bilinear_2d_f16_run`].
///
/// Every argument is forwarded unchanged and the callee's status code
/// is returned as-is.
///
/// # Safety
/// Identical contract to
/// [`baracuda_kernels_interpolate_bilinear_2d_f16_run`].
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
#[inline]
pub unsafe fn baracuda_kernels_upsample_bilinear_2d_fw_f16_run(
    n: i32,
    c: i32,
    ih: i32,
    iw: i32,
    oh: i32,
    ow: i32,
    input: *const c_void,
    output: *mut c_void,
    workspace: *mut c_void,
    workspace_bytes: usize,
    align_corners: i32,
    scale_h_factor: f64,
    scale_w_factor: f64,
    stream: *mut c_void,
) -> i32 {
    // SAFETY: pure pass-through; the caller upholds the callee's contract.
    unsafe {
        baracuda_kernels_interpolate_bilinear_2d_f16_run(
            n, c, ih, iw, oh, ow,
            input, output,
            workspace, workspace_bytes,
            align_corners, scale_h_factor, scale_w_factor,
            stream,
        )
    }
}
/// Bilinear 2-D upsample forward pass, bf16 — `upsample_*` name for
/// [`baracuda_kernels_interpolate_bilinear_2d_bf16_run`].
///
/// Every argument is forwarded unchanged and the callee's status code
/// is returned as-is.
///
/// # Safety
/// Identical contract to
/// [`baracuda_kernels_interpolate_bilinear_2d_bf16_run`].
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
#[inline]
pub unsafe fn baracuda_kernels_upsample_bilinear_2d_fw_bf16_run(
    n: i32,
    c: i32,
    ih: i32,
    iw: i32,
    oh: i32,
    ow: i32,
    input: *const c_void,
    output: *mut c_void,
    workspace: *mut c_void,
    workspace_bytes: usize,
    align_corners: i32,
    scale_h_factor: f64,
    scale_w_factor: f64,
    stream: *mut c_void,
) -> i32 {
    // SAFETY: pure pass-through; the caller upholds the callee's contract.
    unsafe {
        baracuda_kernels_interpolate_bilinear_2d_bf16_run(
            n, c, ih, iw, oh, ow,
            input, output,
            workspace, workspace_bytes,
            align_corners, scale_h_factor, scale_w_factor,
            stream,
        )
    }
}
/// Bilinear 2-D upsample backward pass, f32 — `upsample_*` name for
/// [`baracuda_kernels_interpolate_bilinear_2d_backward_f32_run`].
///
/// Every argument is forwarded unchanged and the callee's status code
/// is returned as-is.
///
/// # Safety
/// Identical contract to
/// [`baracuda_kernels_interpolate_bilinear_2d_backward_f32_run`].
/// The caller must pre-zero `dinput`.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
#[inline]
pub unsafe fn baracuda_kernels_upsample_bilinear_2d_bw_f32_run(
    n: i32,
    c: i32,
    ih: i32,
    iw: i32,
    oh: i32,
    ow: i32,
    dout: *const c_void,
    dinput: *mut c_void,
    workspace: *mut c_void,
    workspace_bytes: usize,
    align_corners: i32,
    scale_h_factor: f64,
    scale_w_factor: f64,
    stream: *mut c_void,
) -> i32 {
    // SAFETY: pure pass-through; the caller upholds the callee's contract.
    unsafe {
        baracuda_kernels_interpolate_bilinear_2d_backward_f32_run(
            n, c, ih, iw, oh, ow,
            dout, dinput,
            workspace, workspace_bytes,
            align_corners, scale_h_factor, scale_w_factor,
            stream,
        )
    }
}
/// Bilinear 2-D upsample backward pass, f64 — `upsample_*` name for
/// [`baracuda_kernels_interpolate_bilinear_2d_backward_f64_run`].
///
/// Every argument is forwarded unchanged and the callee's status code
/// is returned as-is.
///
/// # Safety
/// Identical contract to
/// [`baracuda_kernels_interpolate_bilinear_2d_backward_f64_run`].
/// The caller must pre-zero `dinput`.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
#[inline]
pub unsafe fn baracuda_kernels_upsample_bilinear_2d_bw_f64_run(
    n: i32,
    c: i32,
    ih: i32,
    iw: i32,
    oh: i32,
    ow: i32,
    dout: *const c_void,
    dinput: *mut c_void,
    workspace: *mut c_void,
    workspace_bytes: usize,
    align_corners: i32,
    scale_h_factor: f64,
    scale_w_factor: f64,
    stream: *mut c_void,
) -> i32 {
    // SAFETY: pure pass-through; the caller upholds the callee's contract.
    unsafe {
        baracuda_kernels_interpolate_bilinear_2d_backward_f64_run(
            n, c, ih, iw, oh, ow,
            dout, dinput,
            workspace, workspace_bytes,
            align_corners, scale_h_factor, scale_w_factor,
            stream,
        )
    }
}
/// Bilinear 2-D upsample backward pass, f16 — `upsample_*` name for
/// [`baracuda_kernels_interpolate_bilinear_2d_backward_f16_run`].
///
/// Every argument is forwarded unchanged and the callee's status code
/// is returned as-is.
///
/// # Safety
/// Identical contract to
/// [`baracuda_kernels_interpolate_bilinear_2d_backward_f16_run`].
/// The caller must pre-zero `dinput`.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
#[inline]
pub unsafe fn baracuda_kernels_upsample_bilinear_2d_bw_f16_run(
    n: i32,
    c: i32,
    ih: i32,
    iw: i32,
    oh: i32,
    ow: i32,
    dout: *const c_void,
    dinput: *mut c_void,
    workspace: *mut c_void,
    workspace_bytes: usize,
    align_corners: i32,
    scale_h_factor: f64,
    scale_w_factor: f64,
    stream: *mut c_void,
) -> i32 {
    // SAFETY: pure pass-through; the caller upholds the callee's contract.
    unsafe {
        baracuda_kernels_interpolate_bilinear_2d_backward_f16_run(
            n, c, ih, iw, oh, ow,
            dout, dinput,
            workspace, workspace_bytes,
            align_corners, scale_h_factor, scale_w_factor,
            stream,
        )
    }
}
/// Bilinear 2-D upsample backward pass, bf16 — `upsample_*` name for
/// [`baracuda_kernels_interpolate_bilinear_2d_backward_bf16_run`].
///
/// Every argument is forwarded unchanged and the callee's status code
/// is returned as-is.
///
/// # Safety
/// Identical contract to
/// [`baracuda_kernels_interpolate_bilinear_2d_backward_bf16_run`].
/// The caller must pre-zero `dinput`.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
#[inline]
pub unsafe fn baracuda_kernels_upsample_bilinear_2d_bw_bf16_run(
    n: i32,
    c: i32,
    ih: i32,
    iw: i32,
    oh: i32,
    ow: i32,
    dout: *const c_void,
    dinput: *mut c_void,
    workspace: *mut c_void,
    workspace_bytes: usize,
    align_corners: i32,
    scale_h_factor: f64,
    scale_w_factor: f64,
    stream: *mut c_void,
) -> i32 {
    // SAFETY: pure pass-through; the caller upholds the callee's contract.
    unsafe {
        baracuda_kernels_interpolate_bilinear_2d_backward_bf16_run(
            n, c, ih, iw, oh, ow,
            dout, dinput,
            workspace, workspace_bytes,
            align_corners, scale_h_factor, scale_w_factor,
            stream,
        )
    }
}
// ============================================================================
// Phase 19.2 — Conv / ConvTranspose cuDNN FFI facade
// ============================================================================
//
// Rust `extern "C"` wrappers around the cuDNN convolution family —
// `Conv{1,2,3}d` + `ConvTranspose{1,2,3}d`, each across {f32, f64,
// f16, bf16} for FW + BW-data + BW-filter (72 symbols total).
// Implemented in `conv_cudnn_facade.rs`; re-exported here so callers
// `use baracuda_kernels_sys::baracuda_kernels_conv_2d_fw_f32_run`
// works uniformly with the rest of the FFI surface.
//
// Gated behind `feature = "cudnn"` — same gate as the cuDNN extern
// block + library link line.
#[cfg(feature = "cudnn")]
mod conv_cudnn_facade;
#[cfg(feature = "cudnn")]
pub use conv_cudnn_facade::*;
// Phase 22 — cuSOLVER linalg FFI facade. Pure-Rust `#[no_mangle]`
// wrappers exposing the cuSOLVER-backed linalg plans (Cholesky / LU /
// QR / SVD / Svd-Batched / Svda-Batched / Eigh / Eig / LstSq / Solve /
// Inverse) as flat C symbols for non-Rust callers (Fuel). No feature
// gate — cuSOLVER ships with the CUDA toolkit (not a separate
// download like cuDNN). See module docs for the handle-lifecycle +
// identity-staging contract.
mod cusolver_facade;
pub use cusolver_facade::*;
// Phase 23 — cuFFT FFT FFI facade. Pure-Rust `#[no_mangle]` wrappers
// exposing the cuFFT-backed FFT family (`fft_1d` / `rfft_1d` /
// `irfft_1d` / `fft_nd` / `rfft_nd` / `irfft_nd`, each × {f32, f64} or
// {c32, c64}) as flat C symbols for non-Rust callers (Fuel). 24
// symbols total. No feature gate — cuFFT ships with the CUDA toolkit.
mod cufft_facade;
pub use cufft_facade::*;
// Phase 23 — cuRAND random-sampling FFI facade. Pure-Rust `#[no_mangle]`
// wrappers exposing the cuRAND-backed pure sampler families
// (`curand_uniform`, `curand_normal`, each × {f32, f64}) as flat C
// symbols for non-Rust callers (Fuel). 8 symbols total. Bernoulli /
// Dropout are composites (cuRAND uniform + bespoke kernel) and ship
// directly under their existing bespoke FFI symbols.
//
// cuRAND itself ships with the CUDA toolkit (no separate download
// like cuDNN), but the facade uses the bespoke
// `baracuda_kernels_affine_inplace_{f32,f64}_run` kernels to remap
// `Uniform(0, 1]` → `Uniform(low, high]`, so it inherits the same
// `sm80 / sm89 / sm90a` gate that bespoke kernel ABI requires.
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
mod curand_facade;
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
pub use curand_facade::*;
// Note: cuSPARSE has no Rust plans in `baracuda-kernels` today
// (sparse ops are exposed only via `baracuda-cusparse`'s safe wrapper).
// A cuSPARSE facade lands in a future phase once at least one
// sparse-backed plan exists in `baracuda-kernels` to wrap.
// Phase 24 — CUTLASS GEMM re-export facade. Trampoline `#[no_mangle]`
// wrappers forwarding each of the 162 `baracuda_cutlass_gemm_*` FFI
// symbols (54 SKU families × `_run` / `_workspace_size` /
// `_can_implement`) under the unified `baracuda_kernels_gemm_*` naming
// convention. Covers every Cutlass GEMM SKU shipped by
// `baracuda-cutlass-kernels-sys`: f16/bf16/tf32/f32_simt/f64 × {rcr,
// rrr} × {Identity, Bias, BiasRelu, BiasGelu, BiasSilu} × {f32/i32 bias
// broadcast for int8 / no bias variants} + batched f16/bf16 RCR. Gated
// behind `sm80` / `sm90a` to match the upstream CUTLASS kernel-set
// gates. See module docs for naming convention + status-code +
// workspace contracts.
//
// Skip notes (Phase 24 — no plans, no facade per Phase 23 precedent):
// - cuTENSOR: `baracuda-cutensor` exists, but no `baracuda-kernels`
// plan wraps it (einsum / permute land in a future phase).
// - NPP: `baracuda-npp` exists, but no `baracuda-kernels` plan wraps
// it (image transforms are bespoke + cuDNN-pool today).
// - CV-CUDA: `baracuda-cvcuda` exists, but no `baracuda-kernels` plan
// wraps it.
#[cfg(any(feature = "sm80", feature = "sm90a"))]
mod cutlass_reexport;
#[cfg(any(feature = "sm80", feature = "sm90a"))]
pub use cutlass_reexport::*;
// Phase 74 — dense FP GEMM FFI facade (cuBLAS-backed). Pure-Rust
// `#[no_mangle]` wrappers exposing a plain dense f32/f64/f16/bf16 GEMM
// family (`baracuda_kernels_gemm_dense_*`) with runtime layout tags
// (RRR / RCR / CRR), flexible leading dims, and strided-batch folded
// into the base symbol. Closes the Fuel 2026-06-10 ask — the last
// non-baracuda CUDA surface in Fuel (its own cuBLAS MatMul wrapper).
// No feature gate — cuBLAS ships with the CUDA toolkit and is already
// on this crate's link line. See module docs for the pooled-handle
// lifecycle (deliberate deviation from the transient-handle facade
// convention: GEMM is too hot for per-call create/destroy).
mod gemm_dense_cublas_facade;
pub use gemm_dense_cublas_facade::*;
// =============================================================================
// Phase 53 — bitsandbytes NF4 (NormalFloat 4-bit) dequant + GEMV.
// =============================================================================
//
// Vendored under `vendor/bitsandbytes/` (MIT, Dettmers et al.
// arXiv:2305.14314). Gated behind `feature = "bnb_nf4"` — see the
// Cargo.toml feature block for the rationale.
//
// Packing convention (matches bitsandbytes upstream `Linear4bit`):
// * weight `[N/2, K]` u8 — two 4-bit codes per byte. For byte at
// row `i`, column `k`: low nibble = code for output row `2*i`,
// high nibble = code for output row `2*i+1`.
// * absmax `[N * (K / block_size)]` f32 — per-output-row,
// per-K-block scale. `block_size` typically 64.
// * output `[M, N]` (GEMV) or `[N, K]` (dequant) in T_act ∈
// {f16, bf16}. The f32 dequant FFI exists for the
// quantize→dequant roundtrip smoke test only.
//
// The accumulator stays f32 for every variant. Activation load /
// destination store handle the T_act ↔ f32 cast.
//
// Status codes: `0` success, `2` invalid problem, `5` launch failure.
#[cfg(feature = "bnb_nf4")]
unsafe extern "C" {
// Foreign declarations for the NF4 dequantize and GEMV kernels.
//
// Argument conventions shared by every declaration in this block
// (shapes are taken from the packing-convention note above the block):
// * `n`, `k`, `block_size` - problem dimensions. Every
//   `*_can_implement` query takes exactly these three and no pointers.
// * `w_packed` - read-only packed weight; presumably the `[N/2, K]` u8
//   buffer described above (TODO confirm against the C header).
// * `absmax` - read-only scales; `[N * (K / block_size)]` f32 per the
//   note above.
// * `y` (GEMV only) - read-only input operand; presumably the
//   activation in the variant's f16 / bf16 type (TODO confirm).
// * `out` - destination buffer written by the kernel.
// * `stream` - CUDA stream handle passed as `*mut c_void`.
// Every function returns an `i32` status code.
// -------- Dequant ----------------------------------------------------
/// NF4 dequantize → `[N, K]` `__half`.
pub fn baracuda_kernels_nf4_dequantize_f16_run(
n: i32, k: i32, block_size: i32,
w_packed: *const c_void,
absmax: *const c_void,
out: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Support query for the f16 dequantize path; takes only the
/// dimensions (no buffers, no stream) and returns a status code.
pub fn baracuda_kernels_nf4_dequantize_f16_can_implement(
n: i32, k: i32, block_size: i32,
) -> i32;
/// NF4 dequantize → `[N, K]` `__nv_bfloat16`.
pub fn baracuda_kernels_nf4_dequantize_bf16_run(
n: i32, k: i32, block_size: i32,
w_packed: *const c_void,
absmax: *const c_void,
out: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Support query for the bf16 dequantize path; dimensions only.
pub fn baracuda_kernels_nf4_dequantize_bf16_can_implement(
n: i32, k: i32, block_size: i32,
) -> i32;
/// NF4 dequantize → `[N, K]` `f32`. Smoke-test path only.
pub fn baracuda_kernels_nf4_dequantize_f32_run(
n: i32, k: i32, block_size: i32,
w_packed: *const c_void,
absmax: *const c_void,
out: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Support query for the f32 dequantize path; dimensions only.
pub fn baracuda_kernels_nf4_dequantize_f32_can_implement(
n: i32, k: i32, block_size: i32,
) -> i32;
// -------- GEMV M=1 ---------------------------------------------------
/// NF4 W4A16 GEMV (M=1, single decode vector), f16 activation.
pub fn baracuda_kernels_nf4_gemv_m1_f16_run(
n: i32, k: i32, block_size: i32,
w_packed: *const c_void,
absmax: *const c_void,
y: *const c_void,
out: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Support query for the M=1 f16 GEMV; dimensions only.
pub fn baracuda_kernels_nf4_gemv_m1_f16_can_implement(
n: i32, k: i32, block_size: i32,
) -> i32;
/// NF4 W4A16 GEMV (M=1), bf16 activation.
pub fn baracuda_kernels_nf4_gemv_m1_bf16_run(
n: i32, k: i32, block_size: i32,
w_packed: *const c_void,
absmax: *const c_void,
y: *const c_void,
out: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Support query for the M=1 bf16 GEMV; dimensions only.
pub fn baracuda_kernels_nf4_gemv_m1_bf16_can_implement(
n: i32, k: i32, block_size: i32,
) -> i32;
// -------- GEMV multi-M (compile-time M ∈ {2, 4, 8}) ------------------
// M is baked into the symbol name (`_m2_` / `_m4_` / `_m8_`); it is not
// a runtime argument of any function below.
/// NF4 W4A16 GEMV multi-M=2, f16 activation. Output `[2, N]`.
pub fn baracuda_kernels_nf4_gemv_m2_f16_run(
n: i32, k: i32, block_size: i32,
w_packed: *const c_void,
absmax: *const c_void,
y: *const c_void,
out: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Support query for the M=2 f16 GEMV; dimensions only.
pub fn baracuda_kernels_nf4_gemv_m2_f16_can_implement(
n: i32, k: i32, block_size: i32,
) -> i32;
/// NF4 W4A16 GEMV multi-M=4, f16 activation. Output `[4, N]`.
pub fn baracuda_kernels_nf4_gemv_m4_f16_run(
n: i32, k: i32, block_size: i32,
w_packed: *const c_void,
absmax: *const c_void,
y: *const c_void,
out: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Support query for the M=4 f16 GEMV; dimensions only.
pub fn baracuda_kernels_nf4_gemv_m4_f16_can_implement(
n: i32, k: i32, block_size: i32,
) -> i32;
/// NF4 W4A16 GEMV multi-M=8, f16 activation. Output `[8, N]`.
pub fn baracuda_kernels_nf4_gemv_m8_f16_run(
n: i32, k: i32, block_size: i32,
w_packed: *const c_void,
absmax: *const c_void,
y: *const c_void,
out: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Support query for the M=8 f16 GEMV; dimensions only.
pub fn baracuda_kernels_nf4_gemv_m8_f16_can_implement(
n: i32, k: i32, block_size: i32,
) -> i32;
/// NF4 W4A16 GEMV multi-M=2, bf16 activation.
pub fn baracuda_kernels_nf4_gemv_m2_bf16_run(
n: i32, k: i32, block_size: i32,
w_packed: *const c_void,
absmax: *const c_void,
y: *const c_void,
out: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Support query for the M=2 bf16 GEMV; dimensions only.
pub fn baracuda_kernels_nf4_gemv_m2_bf16_can_implement(
n: i32, k: i32, block_size: i32,
) -> i32;
/// NF4 W4A16 GEMV multi-M=4, bf16 activation.
pub fn baracuda_kernels_nf4_gemv_m4_bf16_run(
n: i32, k: i32, block_size: i32,
w_packed: *const c_void,
absmax: *const c_void,
y: *const c_void,
out: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Support query for the M=4 bf16 GEMV; dimensions only.
pub fn baracuda_kernels_nf4_gemv_m4_bf16_can_implement(
n: i32, k: i32, block_size: i32,
) -> i32;
/// NF4 W4A16 GEMV multi-M=8, bf16 activation.
pub fn baracuda_kernels_nf4_gemv_m8_bf16_run(
n: i32, k: i32, block_size: i32,
w_packed: *const c_void,
absmax: *const c_void,
y: *const c_void,
out: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Support query for the M=8 bf16 GEMV; dimensions only.
pub fn baracuda_kernels_nf4_gemv_m8_bf16_can_implement(
n: i32, k: i32, block_size: i32,
) -> i32;
}
// =============================================================================
// Phase 54 — xFormers cherry-pick: BlockSparseAttention (BSD-3-Clause).
// =============================================================================
//
// Block-sparse SDPA FW where the attention mask is a per-block boolean
// pattern `[B, H, num_blocks_q * num_blocks_k]` (uint8_t). Only the
// active (q_block, k_block) pairs participate in the QK^T matmul +
// online-softmax accumulation. Masked blocks are SKIPPED entirely (no
// K/V load, no compute) — real wall-clock speedup on long-context
// attention with known sparse patterns.
//
// Algorithmic reference: facebookresearch/xformers
// `components/attention/blocksparse.py`. baracuda's kernel is a clean-
// room hand-port that reuses the Phase 6.6 online-softmax tile pipeline
// (see `kernels/include/baracuda_sdpa_block_sparse.cuh`).
//
// Layout contract (rank-4, contiguous, row-major):
// Q : [B, H, Q_len, D_k]
// K : [B, H, K_len, D_k]
// V : [B, H, K_len, D_v]
// y : [B, H, Q_len, D_v]
// lse : [B, H, Q_len]
// block_pattern : [B, H, num_blocks_q * num_blocks_k] (uint8_t)
//
// Tier-1 constraints:
// block_size ∈ [1, 64]
// d_k = d_v ≤ 128
// FW only (no BW); causal mask supported (composes with the block
// pattern — masked blocks AND causal-suppressed cells are skipped).
//
// Symbols gated behind the `xformers_blocksparse` cargo feature.
#[cfg(feature = "xformers_blocksparse")]
unsafe extern "C" {
// Foreign declarations for block-sparse SDPA forward: one `_run` and
// one `_can_implement` per dtype (f32, f16, bf16, f64).
//
// Argument conventions (layouts per the contract above the block):
// * `batch`, `heads`, `q_len`, `k_len`, `d_k`, `d_v` - the B, H, Q_len,
//   K_len, D_k, D_v extents of the layout contract.
// * `block_size` - block granularity of `block_pattern`.
// * `scale` - `f32` scalar in every variant, including f64; presumably
//   the softmax scale applied to QK^T (TODO confirm).
// * `is_causal` - `i32` flag for the causal mask; presumably non-zero
//   enables it (TODO confirm).
// * `q`, `k`, `v`, `block_pattern` - read-only inputs.
// * `y`, `lse` - outputs written by the kernel.
// * `workspace`, `workspace_bytes` - scratch buffer and its size.
//   NOTE(review): `workspace_bytes` is `u64` in this block, whereas
//   other blocks in this file declare it as `usize`. No workspace-size
//   query is declared here, so the required size is not visible.
// * `stream` - CUDA stream handle passed as `*mut c_void`.
// Every function returns an `i32` status code.
/// Block-sparse SDPA forward, f32.
pub fn baracuda_kernels_sdpa_f32_block_sparse_run(
batch: i32, heads: i32, q_len: i32, k_len: i32,
d_k: i32, d_v: i32, block_size: i32,
scale: f32, is_causal: i32,
q: *const c_void, k: *const c_void, v: *const c_void,
block_pattern: *const c_void,
y: *mut c_void, lse: *mut c_void,
workspace: *mut c_void, workspace_bytes: u64,
stream: *mut c_void,
) -> i32;
/// Block-sparse SDPA forward, f16.
pub fn baracuda_kernels_sdpa_f16_block_sparse_run(
batch: i32, heads: i32, q_len: i32, k_len: i32,
d_k: i32, d_v: i32, block_size: i32,
scale: f32, is_causal: i32,
q: *const c_void, k: *const c_void, v: *const c_void,
block_pattern: *const c_void,
y: *mut c_void, lse: *mut c_void,
workspace: *mut c_void, workspace_bytes: u64,
stream: *mut c_void,
) -> i32;
/// Block-sparse SDPA forward, bf16.
pub fn baracuda_kernels_sdpa_bf16_block_sparse_run(
batch: i32, heads: i32, q_len: i32, k_len: i32,
d_k: i32, d_v: i32, block_size: i32,
scale: f32, is_causal: i32,
q: *const c_void, k: *const c_void, v: *const c_void,
block_pattern: *const c_void,
y: *mut c_void, lse: *mut c_void,
workspace: *mut c_void, workspace_bytes: u64,
stream: *mut c_void,
) -> i32;
/// Block-sparse SDPA forward, f64. Note that `scale` is still `f32`.
pub fn baracuda_kernels_sdpa_f64_block_sparse_run(
batch: i32, heads: i32, q_len: i32, k_len: i32,
d_k: i32, d_v: i32, block_size: i32,
scale: f32, is_causal: i32,
q: *const c_void, k: *const c_void, v: *const c_void,
block_pattern: *const c_void,
y: *mut c_void, lse: *mut c_void,
workspace: *mut c_void, workspace_bytes: u64,
stream: *mut c_void,
) -> i32;
/// Support query for the f32 variant; takes only the seven dimension
/// arguments (no `scale`, `is_causal`, buffers, or stream).
pub fn baracuda_kernels_sdpa_f32_block_sparse_can_implement(
batch: i32, heads: i32, q_len: i32, k_len: i32,
d_k: i32, d_v: i32, block_size: i32,
) -> i32;
/// Support query for the f16 variant; dimensions only.
pub fn baracuda_kernels_sdpa_f16_block_sparse_can_implement(
batch: i32, heads: i32, q_len: i32, k_len: i32,
d_k: i32, d_v: i32, block_size: i32,
) -> i32;
/// Support query for the bf16 variant; dimensions only.
pub fn baracuda_kernels_sdpa_bf16_block_sparse_can_implement(
batch: i32, heads: i32, q_len: i32, k_len: i32,
d_k: i32, d_v: i32, block_size: i32,
) -> i32;
/// Support query for the f64 variant; dimensions only.
pub fn baracuda_kernels_sdpa_f64_block_sparse_can_implement(
batch: i32, heads: i32, q_len: i32, k_len: i32,
d_k: i32, d_v: i32, block_size: i32,
) -> i32;
}
// =============================================================================
// Phase 54 — xFormers cherry-pick: 2:4 Structured Sparsity GEMM
// (BSD-3-Clause).
// =============================================================================
//
// 2:4 pattern: in every 4 consecutive weight cells, AT MOST 2 are
// non-zero. Compressed format:
// W_compressed: [M, K/2] of dtype T
// W_metadata: [M, K/8] of uint16_t (2 bytes; each byte encodes one
// 4-group's 2 non-zero positions — low 2 bits = pos0,
// bits [2:3] = pos1).
//
// Output: Y[N, M] = X[N, K] @ W_dense^T (where W_dense is the inflated
// [M, K] tensor from W_compressed + W_metadata).
//
// Algorithmic reference: facebookresearch/xformers `sparse24/`. baracuda's
// kernel is a clean-room hand-port (see
// `kernels/include/baracuda_gemm_sparse24.cuh`).
//
// Tier-1 implementation: **inflate-then-dense-matmul** path.
// Sparse-tensor-core (`mma.sp.sync.aligned`) hardware speedup deferred
// to Tier 2 alongside cuSPARSELt integration.
//
// Layout contract: all row-major contiguous.
// K must be a multiple of 8 (each uint16 metadata word covers two
// 4-groups, one in its low byte and one in its high byte).
//
// Workspace: `M * K * sizeof(T)` bytes for the inflated dense W tile.
//
// Symbols gated behind the `xformers_sparse24` cargo feature.
#[cfg(feature = "xformers_sparse24")]
unsafe extern "C" {
// Foreign declarations for the 2:4 structured-sparsity path, per dtype
// (f32, f16, bf16): `_inflate`, `_gemm_run`, `_gemm_can_implement` and
// `_gemm_workspace_bytes`.
//
// Argument conventions (shapes per the format note above the block):
// * `m`, `k` (and `n` for the GEMM) - dimensions; W_compressed is
//   `[M, K/2]`, W_metadata is `[M, K/8]` uint16, X is `[N, K]`, Y is
//   `[N, M]`.
// * `x`, `w_compressed`, `w_metadata` - read-only inputs.
// * `w_dense` / `y` - outputs written by the kernel.
// * `workspace`, `workspace_bytes` - scratch buffer and its size.
//   NOTE(review): sizes are `u64` in this block, whereas other blocks
//   in this file use `usize`.
// * `stream` - CUDA stream handle passed as `*mut c_void`.
/// Inflate the compressed sparse-24 weight to the dense `[M, K]`
/// representation in caller-owned memory. Use this when you want
/// to drive the matmul through your own GEMM kernel afterwards.
pub fn baracuda_kernels_gemm_f32_sparse24_inflate(
m: i32, k: i32,
w_compressed: *const c_void, w_metadata: *const c_void,
w_dense: *mut c_void,
stream: *mut c_void,
) -> i32;
/// f16 variant of the sparse-24 inflate; same arguments as the f32
/// declaration above.
pub fn baracuda_kernels_gemm_f16_sparse24_inflate(
m: i32, k: i32,
w_compressed: *const c_void, w_metadata: *const c_void,
w_dense: *mut c_void,
stream: *mut c_void,
) -> i32;
/// bf16 variant of the sparse-24 inflate; same arguments as the f32
/// declaration above.
pub fn baracuda_kernels_gemm_bf16_sparse24_inflate(
m: i32, k: i32,
w_compressed: *const c_void, w_metadata: *const c_void,
w_dense: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Reference end-to-end GEMM with on-the-fly inflation through the
/// caller-supplied workspace (size = `M * K * sizeof(T)` bytes).
/// Output Y[N, M] = X[N, K] @ inflate(W_compressed)^T.
pub fn baracuda_kernels_gemm_f32_sparse24_gemm_run(
n: i32, m: i32, k: i32,
x: *const c_void, w_compressed: *const c_void, w_metadata: *const c_void,
y: *mut c_void,
workspace: *mut c_void, workspace_bytes: u64,
stream: *mut c_void,
) -> i32;
/// f16 variant of the reference sparse-24 GEMM; same arguments as the
/// f32 declaration above.
pub fn baracuda_kernels_gemm_f16_sparse24_gemm_run(
n: i32, m: i32, k: i32,
x: *const c_void, w_compressed: *const c_void, w_metadata: *const c_void,
y: *mut c_void,
workspace: *mut c_void, workspace_bytes: u64,
stream: *mut c_void,
) -> i32;
/// bf16 variant of the reference sparse-24 GEMM; same arguments as the
/// f32 declaration above.
pub fn baracuda_kernels_gemm_bf16_sparse24_gemm_run(
n: i32, m: i32, k: i32,
x: *const c_void, w_compressed: *const c_void, w_metadata: *const c_void,
y: *mut c_void,
workspace: *mut c_void, workspace_bytes: u64,
stream: *mut c_void,
) -> i32;
/// Support query for the f32 sparse-24 GEMM; takes `(n, m, k)` only
/// and returns a status code.
pub fn baracuda_kernels_gemm_f32_sparse24_gemm_can_implement(
n: i32, m: i32, k: i32,
) -> i32;
/// Support query for the f16 sparse-24 GEMM; `(n, m, k)` only.
pub fn baracuda_kernels_gemm_f16_sparse24_gemm_can_implement(
n: i32, m: i32, k: i32,
) -> i32;
/// Support query for the bf16 sparse-24 GEMM; `(n, m, k)` only.
pub fn baracuda_kernels_gemm_bf16_sparse24_gemm_can_implement(
n: i32, m: i32, k: i32,
) -> i32;
/// Workspace-size query for the f32 sparse-24 GEMM. Returns a byte
/// count (`u64`), not a status code; per the note above the block the
/// workspace holds the inflated dense W (`M * K * sizeof(T)` bytes).
pub fn baracuda_kernels_gemm_f32_sparse24_gemm_workspace_bytes(
n: i32, m: i32, k: i32,
) -> u64;
/// Workspace-size query for the f16 sparse-24 GEMM; returns bytes.
pub fn baracuda_kernels_gemm_f16_sparse24_gemm_workspace_bytes(
n: i32, m: i32, k: i32,
) -> u64;
/// Workspace-size query for the bf16 sparse-24 GEMM; returns bytes.
pub fn baracuda_kernels_gemm_bf16_sparse24_gemm_workspace_bytes(
n: i32, m: i32, k: i32,
) -> u64;
}
// =============================================================================
// Phase 50 — Dao-AILab causal-conv1d (BSD-3-Clause) + state-spaces/mamba SSD
// chunk-scan (Apache-2.0). Phase 50b — Mamba-1 selective_scan sibling. All
// gated behind the `mamba` cargo feature.
// =============================================================================
//
// Causal-conv1d: depthwise causal cross-correlation primitive used between
// the Mamba input projection and the SSM block. Trailblazer constraints:
// - widths W ∈ {2, 3, 4}
// - activation: SiLU or identity (use_silu = 1 / 0)
// - dtypes: f32 / f16 / bf16 / f64
// - layout: NCL contiguous
// - FW deterministic / bit-stable; BW dw/db atomic-accumulated.
//
// SSD chunk-scan: Mamba-2 selective SSM via the State-Space Duality
// reformulation. Per-(b, h) sequential recurrence with state in SMEM.
// Trailblazer constraints: dtypes f32 / f16 / bf16; head_dim/state_dim
// ≤ 256 for FW, ≤ 64 for BW. BW workspace = `B * H * L * D * N` f32
// recorded states (query via `_workspace_bytes`).
//
// Selective_scan: Mamba-1 selective SSM. Per-(b, d) sequential recurrence
// with per-(d, n) state matrix in SMEM. Trailblazer constraints: dtypes
// f32 / f16 / bf16; `dstate` (N) ≤ 256. BW workspace =
// `B * D * L * N * sizeof(T)` (query via `_workspace_bytes`).
//
// Status codes match baracuda convention: 0 ok, 2 invalid_problem,
// 3 unsupported, 4 workspace_too_small, 5 launch_failure.
#[cfg(feature = "mamba")]
unsafe extern "C" {
// Foreign declarations for three kernel families: causal-conv1d,
// SSD chunk-scan (Mamba-2) and selective_scan (Mamba-1). Each family
// has forward and backward `_run` entry points plus `_can_implement`
// queries; the two scan families also expose one `_workspace_bytes`
// query shared across dtypes.
//
// Conventions visible in the signatures:
// * `*const c_void` arguments are read-only inputs; `*mut c_void`
//   arguments before `workspace` are outputs.
// * `workspace` / `workspace_bytes` (`usize`) - scratch buffer and its
//   size in bytes.
// * `stream` - CUDA stream handle passed as `*mut c_void`.
// * Every `_run` / `_can_implement` returns an `i32` status code; the
//   `_workspace_bytes` queries return a byte count instead.
// NOTE(review): which pointer arguments may be null (e.g. `bias`,
// `d_skip`, `z`, `delta_bias`) is not visible here - confirm against
// the C header before relying on it.
// -----------------------------------------------------------------
// causal-conv1d FW (Phase 50)
// -----------------------------------------------------------------
// `use_silu` selects the activation (1 = SiLU, 0 = identity, per the
// section note above the block).
/// Causal-conv1d forward, f32: reads `x`, `weight`, `bias`; writes `y`.
pub fn baracuda_kernels_causal_conv1d_f32_run(
batch: i32, channels: i32, seqlen: i32, width: i32,
use_silu: i32,
x: *const c_void, weight: *const c_void, bias: *const c_void,
y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Causal-conv1d forward, f16.
pub fn baracuda_kernels_causal_conv1d_f16_run(
batch: i32, channels: i32, seqlen: i32, width: i32,
use_silu: i32,
x: *const c_void, weight: *const c_void, bias: *const c_void,
y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Causal-conv1d forward, bf16.
pub fn baracuda_kernels_causal_conv1d_bf16_run(
batch: i32, channels: i32, seqlen: i32, width: i32,
use_silu: i32,
x: *const c_void, weight: *const c_void, bias: *const c_void,
y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Causal-conv1d forward, f64.
pub fn baracuda_kernels_causal_conv1d_f64_run(
batch: i32, channels: i32, seqlen: i32, width: i32,
use_silu: i32,
x: *const c_void, weight: *const c_void, bias: *const c_void,
y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Support query for the f32 causal-conv1d forward; takes the four
/// dimensions only (no `use_silu`, unlike the backward query).
pub fn baracuda_kernels_causal_conv1d_f32_can_implement(
batch: i32, channels: i32, seqlen: i32, width: i32,
) -> i32;
/// Support query for the f16 causal-conv1d forward; dimensions only.
pub fn baracuda_kernels_causal_conv1d_f16_can_implement(
batch: i32, channels: i32, seqlen: i32, width: i32,
) -> i32;
/// Support query for the bf16 causal-conv1d forward; dimensions only.
pub fn baracuda_kernels_causal_conv1d_bf16_can_implement(
batch: i32, channels: i32, seqlen: i32, width: i32,
) -> i32;
/// Support query for the f64 causal-conv1d forward; dimensions only.
pub fn baracuda_kernels_causal_conv1d_f64_can_implement(
batch: i32, channels: i32, seqlen: i32, width: i32,
) -> i32;
// -----------------------------------------------------------------
// causal-conv1d BW (Phase 50)
// -----------------------------------------------------------------
// The section note above the block says `dw` / `db` are
// atomic-accumulated; presumably the caller must zero them before the
// call - TODO confirm.
/// Causal-conv1d backward, f32: reads `x`, `weight`, `bias`, `dy`;
/// writes `dx`, `dw`, `db`.
pub fn baracuda_kernels_causal_conv1d_f32_backward_run(
batch: i32, channels: i32, seqlen: i32, width: i32,
use_silu: i32,
x: *const c_void, weight: *const c_void, bias: *const c_void, dy: *const c_void,
dx: *mut c_void, dw: *mut c_void, db: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Support query for the f32 causal-conv1d backward; takes the four
/// dimensions plus `use_silu`.
pub fn baracuda_kernels_causal_conv1d_f32_backward_can_implement(
batch: i32, channels: i32, seqlen: i32, width: i32, use_silu: i32,
) -> i32;
/// Causal-conv1d backward, f16.
pub fn baracuda_kernels_causal_conv1d_f16_backward_run(
batch: i32, channels: i32, seqlen: i32, width: i32,
use_silu: i32,
x: *const c_void, weight: *const c_void, bias: *const c_void, dy: *const c_void,
dx: *mut c_void, dw: *mut c_void, db: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Support query for the f16 causal-conv1d backward.
pub fn baracuda_kernels_causal_conv1d_f16_backward_can_implement(
batch: i32, channels: i32, seqlen: i32, width: i32, use_silu: i32,
) -> i32;
/// Causal-conv1d backward, bf16.
pub fn baracuda_kernels_causal_conv1d_bf16_backward_run(
batch: i32, channels: i32, seqlen: i32, width: i32,
use_silu: i32,
x: *const c_void, weight: *const c_void, bias: *const c_void, dy: *const c_void,
dx: *mut c_void, dw: *mut c_void, db: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Support query for the bf16 causal-conv1d backward.
pub fn baracuda_kernels_causal_conv1d_bf16_backward_can_implement(
batch: i32, channels: i32, seqlen: i32, width: i32, use_silu: i32,
) -> i32;
/// Causal-conv1d backward, f64.
pub fn baracuda_kernels_causal_conv1d_f64_backward_run(
batch: i32, channels: i32, seqlen: i32, width: i32,
use_silu: i32,
x: *const c_void, weight: *const c_void, bias: *const c_void, dy: *const c_void,
dx: *mut c_void, dw: *mut c_void, db: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Support query for the f64 causal-conv1d backward.
pub fn baracuda_kernels_causal_conv1d_f64_backward_can_implement(
batch: i32, channels: i32, seqlen: i32, width: i32, use_silu: i32,
) -> i32;
// -----------------------------------------------------------------
// SSD chunk-scan FW (Phase 50)
// -----------------------------------------------------------------
/// SSD chunk-scan forward, f32: reads `x`, `dt`, `a`, `b`, `c`;
/// writes `y`.
pub fn baracuda_kernels_ssd_chunk_scan_f32_run(
batch: i32, seqlen: i32, heads: i32,
head_dim: i32, state_dim: i32, chunk_size: i32,
x: *const c_void, dt: *const c_void, a: *const c_void,
b: *const c_void, c: *const c_void,
y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// SSD chunk-scan forward, f16.
pub fn baracuda_kernels_ssd_chunk_scan_f16_run(
batch: i32, seqlen: i32, heads: i32,
head_dim: i32, state_dim: i32, chunk_size: i32,
x: *const c_void, dt: *const c_void, a: *const c_void,
b: *const c_void, c: *const c_void,
y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// SSD chunk-scan forward, bf16.
pub fn baracuda_kernels_ssd_chunk_scan_bf16_run(
batch: i32, seqlen: i32, heads: i32,
head_dim: i32, state_dim: i32, chunk_size: i32,
x: *const c_void, dt: *const c_void, a: *const c_void,
b: *const c_void, c: *const c_void,
y: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Support query for the f32 SSD chunk-scan forward; dimensions only.
pub fn baracuda_kernels_ssd_chunk_scan_f32_can_implement(
batch: i32, seqlen: i32, heads: i32,
head_dim: i32, state_dim: i32, chunk_size: i32,
) -> i32;
/// Support query for the f16 SSD chunk-scan forward; dimensions only.
pub fn baracuda_kernels_ssd_chunk_scan_f16_can_implement(
batch: i32, seqlen: i32, heads: i32,
head_dim: i32, state_dim: i32, chunk_size: i32,
) -> i32;
/// Support query for the bf16 SSD chunk-scan forward; dimensions only.
pub fn baracuda_kernels_ssd_chunk_scan_bf16_can_implement(
batch: i32, seqlen: i32, heads: i32,
head_dim: i32, state_dim: i32, chunk_size: i32,
) -> i32;
// -----------------------------------------------------------------
// SSD chunk-scan BW (Phase 50)
// -----------------------------------------------------------------
/// Workspace-size query for the SSD chunk-scan backward. Returns a
/// byte count (`usize`), not a status code. One symbol serves all
/// dtypes via `dtype_id`; the `dtype_id` encoding is not visible in
/// this file - TODO confirm against the C header.
pub fn baracuda_kernels_ssd_chunk_scan_workspace_bytes(
batch: i32, seqlen: i32, heads: i32,
head_dim: i32, state_dim: i32,
chunk_size: i32, dtype_id: i32,
) -> usize;
/// SSD chunk-scan backward, f32: reads the forward inputs plus `dy`;
/// writes `dx`, `d_b`, `d_c`, `d_dt`, `d_a`.
pub fn baracuda_kernels_ssd_chunk_scan_f32_backward_run(
batch: i32, seqlen: i32, heads: i32,
head_dim: i32, state_dim: i32, chunk_size: i32,
x: *const c_void, dt: *const c_void, a: *const c_void,
b: *const c_void, c: *const c_void, dy: *const c_void,
dx: *mut c_void, d_b: *mut c_void, d_c: *mut c_void,
d_dt: *mut c_void, d_a: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Support query for the f32 SSD chunk-scan backward; dimensions only.
pub fn baracuda_kernels_ssd_chunk_scan_f32_backward_can_implement(
batch: i32, seqlen: i32, heads: i32, head_dim: i32, state_dim: i32, chunk_size: i32,
) -> i32;
/// SSD chunk-scan backward, f16.
pub fn baracuda_kernels_ssd_chunk_scan_f16_backward_run(
batch: i32, seqlen: i32, heads: i32,
head_dim: i32, state_dim: i32, chunk_size: i32,
x: *const c_void, dt: *const c_void, a: *const c_void,
b: *const c_void, c: *const c_void, dy: *const c_void,
dx: *mut c_void, d_b: *mut c_void, d_c: *mut c_void,
d_dt: *mut c_void, d_a: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Support query for the f16 SSD chunk-scan backward; dimensions only.
pub fn baracuda_kernels_ssd_chunk_scan_f16_backward_can_implement(
batch: i32, seqlen: i32, heads: i32, head_dim: i32, state_dim: i32, chunk_size: i32,
) -> i32;
/// SSD chunk-scan backward, bf16.
pub fn baracuda_kernels_ssd_chunk_scan_bf16_backward_run(
batch: i32, seqlen: i32, heads: i32,
head_dim: i32, state_dim: i32, chunk_size: i32,
x: *const c_void, dt: *const c_void, a: *const c_void,
b: *const c_void, c: *const c_void, dy: *const c_void,
dx: *mut c_void, d_b: *mut c_void, d_c: *mut c_void,
d_dt: *mut c_void, d_a: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Support query for the bf16 SSD chunk-scan backward; dimensions only.
pub fn baracuda_kernels_ssd_chunk_scan_bf16_backward_can_implement(
batch: i32, seqlen: i32, heads: i32, head_dim: i32, state_dim: i32, chunk_size: i32,
) -> i32;
// -----------------------------------------------------------------
// selective_scan FW (Phase 50b)
// -----------------------------------------------------------------
// `delta_softplus` is an `i32` flag; presumably non-zero applies
// softplus to `delta` - TODO confirm.
/// Selective-scan forward, f32: reads `u`, `delta`, `a`, `b`, `c`,
/// `d_skip`, `z`, `delta_bias`; writes `y` and `last_state`.
pub fn baracuda_kernels_selective_scan_f32_run(
batch: i32, seqlen: i32, dim: i32, dstate: i32,
delta_softplus: i32,
u: *const c_void, delta: *const c_void, a: *const c_void,
b: *const c_void, c: *const c_void,
d_skip: *const c_void, z: *const c_void, delta_bias: *const c_void,
y: *mut c_void, last_state: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Selective-scan forward, f16.
pub fn baracuda_kernels_selective_scan_f16_run(
batch: i32, seqlen: i32, dim: i32, dstate: i32,
delta_softplus: i32,
u: *const c_void, delta: *const c_void, a: *const c_void,
b: *const c_void, c: *const c_void,
d_skip: *const c_void, z: *const c_void, delta_bias: *const c_void,
y: *mut c_void, last_state: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Selective-scan forward, bf16.
pub fn baracuda_kernels_selective_scan_bf16_run(
batch: i32, seqlen: i32, dim: i32, dstate: i32,
delta_softplus: i32,
u: *const c_void, delta: *const c_void, a: *const c_void,
b: *const c_void, c: *const c_void,
d_skip: *const c_void, z: *const c_void, delta_bias: *const c_void,
y: *mut c_void, last_state: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Support query for the f32 selective-scan forward; dimensions only.
pub fn baracuda_kernels_selective_scan_f32_can_implement(
batch: i32, seqlen: i32, dim: i32, dstate: i32,
) -> i32;
/// Support query for the f16 selective-scan forward; dimensions only.
pub fn baracuda_kernels_selective_scan_f16_can_implement(
batch: i32, seqlen: i32, dim: i32, dstate: i32,
) -> i32;
/// Support query for the bf16 selective-scan forward; dimensions only.
pub fn baracuda_kernels_selective_scan_bf16_can_implement(
batch: i32, seqlen: i32, dim: i32, dstate: i32,
) -> i32;
// -----------------------------------------------------------------
// selective_scan BW (Phase 50b)
// -----------------------------------------------------------------
/// Workspace-size query for the selective-scan backward. Returns a
/// byte count (`usize`), not a status code. One symbol serves all
/// dtypes via `dtype_id`; the `dtype_id` encoding is not visible in
/// this file - TODO confirm against the C header.
pub fn baracuda_kernels_selective_scan_workspace_bytes(
batch: i32, seqlen: i32, dim: i32, dstate: i32,
dtype_id: i32,
) -> usize;
/// Selective-scan backward, f32: reads the forward inputs plus `dy`;
/// writes `du`, `d_b`, `d_c`, `d_delta`, `d_a`, `d_d`, `dz`,
/// `d_delta_bias`.
pub fn baracuda_kernels_selective_scan_f32_backward_run(
batch: i32, seqlen: i32, dim: i32, dstate: i32,
delta_softplus: i32,
u: *const c_void, delta: *const c_void, a: *const c_void,
b: *const c_void, c: *const c_void,
d_skip: *const c_void, z: *const c_void, delta_bias: *const c_void,
dy: *const c_void,
du: *mut c_void, d_b: *mut c_void, d_c: *mut c_void, d_delta: *mut c_void,
d_a: *mut c_void, d_d: *mut c_void, dz: *mut c_void, d_delta_bias: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Support query for the f32 selective-scan backward; dimensions only.
pub fn baracuda_kernels_selective_scan_f32_backward_can_implement(
batch: i32, seqlen: i32, dim: i32, dstate: i32,
) -> i32;
/// Selective-scan backward, f16.
pub fn baracuda_kernels_selective_scan_f16_backward_run(
batch: i32, seqlen: i32, dim: i32, dstate: i32,
delta_softplus: i32,
u: *const c_void, delta: *const c_void, a: *const c_void,
b: *const c_void, c: *const c_void,
d_skip: *const c_void, z: *const c_void, delta_bias: *const c_void,
dy: *const c_void,
du: *mut c_void, d_b: *mut c_void, d_c: *mut c_void, d_delta: *mut c_void,
d_a: *mut c_void, d_d: *mut c_void, dz: *mut c_void, d_delta_bias: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Support query for the f16 selective-scan backward; dimensions only.
pub fn baracuda_kernels_selective_scan_f16_backward_can_implement(
batch: i32, seqlen: i32, dim: i32, dstate: i32,
) -> i32;
/// Selective-scan backward, bf16.
pub fn baracuda_kernels_selective_scan_bf16_backward_run(
batch: i32, seqlen: i32, dim: i32, dstate: i32,
delta_softplus: i32,
u: *const c_void, delta: *const c_void, a: *const c_void,
b: *const c_void, c: *const c_void,
d_skip: *const c_void, z: *const c_void, delta_bias: *const c_void,
dy: *const c_void,
du: *mut c_void, d_b: *mut c_void, d_c: *mut c_void, d_delta: *mut c_void,
d_a: *mut c_void, d_d: *mut c_void, dz: *mut c_void, d_delta_bias: *mut c_void,
workspace: *mut c_void, workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// Support query for the bf16 selective-scan backward; dimensions only.
pub fn baracuda_kernels_selective_scan_bf16_backward_can_implement(
batch: i32, seqlen: i32, dim: i32, dstate: i32,
) -> i32;
}
// ============================================================
// Phase 48 (Consolidation B) — Marlin + AWQ 4-bit GEMM
// ============================================================
//
// Two vendored 4-bit GEMM kernels with complementary scope:
//
// * Marlin (IST-DASLab, Apache-2.0 + §3 patent grant) — **symmetric**
// int4 W4A16 GEMM. Reports ~3.87x speedup over FP16 GEMM at
// batch sizes 1-32 on Ampere / Ada GPUs. Targets sm_80 / sm_86 /
// sm_89; NOT sm_90 (Hopper). Goal A of Phase 48.
// * AWQ (mit-han-lab, MIT — no patent grant) — **asymmetric** int4
// W4A16 GEMM with explicit per-group zero-points. Loads
// directly from HF `*-AWQ` checkpoints without repack. Goal B
// of Phase 48.
//
// The asymmetric→symmetric bridge (loading GPTQ-format weights into
// Marlin) is a host-side repack utility in
// `baracuda-kernels::gemm::gptq_to_marlin` (Goal C of Phase 48). It
// is a pure-Rust algorithmic routine and does not surface here.
//
// Vendored sources at `vendor/marlin/` and `vendor/awq/`. Each ships
// a `VENDOR.md` documenting the upstream commit pin, license terms,
// and the scope of the vendored slice. Gated behind the `marlin` and
// `awq` cargo features respectively (both default OFF).
//
// Status codes (matches the rest of the surface):
// 0 = success
// 2 = invalid problem (bad alignment, group size, ...)
// 3 = unsupported configuration (e.g. AWQ dequant stub)
// 4 = workspace too small (AWQ only)
// 5 = launch failure
#[cfg(feature = "marlin")]
unsafe extern "C" {
// Foreign declarations for the Marlin symmetric int4 GEMM: one `_run`
// entry point and one `_can_implement` validator, f16 only.
/// Marlin W4A16 GEMM — symmetric int4 weights, fp16 activation +
/// output. M can be any non-negative value; the kernel internally
/// tiles into M-block groups (16 rows / block, ≤ 4 blocks per
/// kernel launch, ≤ `max_par` parallel kernel launches per outer
/// iteration).
///
/// `B` is `[K/16, N*16/8]` `int32` pre-shuffled int4 weight tile.
/// `scales` is `[K/groupsize, N]` `__half` per-group scales (or
/// `[1, N]` for `groupsize == -1`), pre-permuted by the packer.
/// `workspace` must be a zero-initialised `int32` buffer with
/// `>= (N / 128) * max_par` entries.
/// `groupsize ∈ {-1, 128}`; `max_par` is the parallel-tile upper
/// bound (typical 16, matching upstream).
///
/// `a` is the read-only activation and `c` the written output
/// (presumably `[M, K]` and `[M, N]` `__half` respectively — TODO
/// confirm). Unlike most `_run` functions in this file there is no
/// `workspace_bytes` argument, so the workspace size is not checked
/// through this signature. `stream` is a CUDA stream handle passed as
/// `*mut c_void`. Returns an `i32` status code.
pub fn baracuda_kernels_int4_marlin_gemm_f16_run(
M: i32, N: i32, K: i32,
a: *const c_void,
b: *const c_void,
c: *mut c_void,
scales: *const c_void,
workspace: *mut c_void,
groupsize: i32,
max_par: i32,
stream: *mut c_void,
) -> i32;
/// Marlin GEMM shape/alignment validator (no kernel launch).
/// Returns 0 if the (M, N, K, groupsize) tuple is in the
/// supported range; 2 otherwise.
pub fn baracuda_kernels_int4_marlin_gemm_f16_can_implement(
M: i32, N: i32, K: i32,
groupsize: i32,
) -> i32;
}
#[cfg(feature = "awq")]
unsafe extern "C" {
// Foreign declarations for the AWQ asymmetric int4 GEMM (run,
// workspace-size query, validator) plus a dequantize stub pair.
/// AWQ W4A16 GEMM — asymmetric int4 with explicit per-group
/// zero-points, fp16 activation + output, f32 accumulator.
///
/// `in_feats` is `[M, IC]` row-major `__half`.
/// `kernel_weights` is `[OC, IC/8]` `int32` packed int4 (OC-major,
/// IC-minor; transpose of the naive `[K, N]`).
/// `scaling_factors` is `[IC/group_size, OC]` `__half`.
/// `zeros` is `[IC/group_size, OC/8]` `int32` packed int4
/// zero-points.
/// `out` is `[M, OC]` row-major `__half` output.
/// `workspace` is `[split_k_iters, padded_M, OC]` `__half` staging
/// for the split-k partial sums (`padded_M = ceil(M / 128) * 128`,
/// i.e. `M` rounded up to a multiple of 128).
/// Size in bytes via
/// [`baracuda_kernels_int4_awq_gemm_f16_workspace_bytes`].
/// `group_size ∈ {64, 128}`. `split_k_iters` is caller-chosen;
/// typical 8.
///
/// `stream` is a CUDA stream handle passed as `*mut c_void`.
/// Returns an `i32` status code.
pub fn baracuda_kernels_int4_awq_gemm_f16_run(
M: i32, IC: i32, OC: i32,
group_size: i32, split_k_iters: i32,
in_feats: *const c_void,
kernel_weights: *const c_void,
scaling_factors: *const c_void,
zeros: *const c_void,
out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// AWQ GEMM workspace-size query. Returns the staging-buffer
/// requirement in bytes for `(M, OC, split_k_iters)`. Returns 0
/// for any non-positive input.
pub fn baracuda_kernels_int4_awq_gemm_f16_workspace_bytes(
M: i32, OC: i32, split_k_iters: i32,
) -> usize;
/// AWQ GEMM shape/alignment validator (no kernel launch).
pub fn baracuda_kernels_int4_awq_gemm_f16_can_implement(
M: i32, IC: i32, OC: i32,
group_size: i32, split_k_iters: i32,
) -> i32;
/// AWQ dequant stub. AWQ does not ship a standalone dequant
/// kernel upstream (the dequant lives inside the GEMM as a
/// per-tile staging step). This entry point always returns 3
/// (unsupported); kept for FFI-surface symmetry with the other
/// 4-bit families. Future expansion may add a real dequant
/// (synthesised from the GEMM with identity activations).
pub fn baracuda_kernels_int4_awq_dequantize_f16_run(
N: i32, K: i32, group_size: i32,
kernel_weights: *const c_void,
scaling_factors: *const c_void,
zeros: *const c_void,
out: *mut c_void,
stream: *mut c_void,
) -> i32;
/// Validator companion of the dequant stub above.
///
/// NOTE(review): unlike the other `_can_implement` declarations in
/// this block, this one also takes the buffer pointers (with `out` as
/// `*const`). Confirm the parameter list against the C header —
/// a mismatch with the C definition would be an ABI bug.
pub fn baracuda_kernels_int4_awq_dequantize_f16_can_implement(
N: i32, K: i32, group_size: i32,
kernel_weights: *const c_void,
scaling_factors: *const c_void,
zeros: *const c_void,
out: *const c_void,
) -> i32;
}
// ============================================================
// Phase 56 — Ring Attention (sequence-parallel attention)
// ============================================================
//
// Clean-room CUDA port of the Liu/Yan/Abbeel 2023 Ring Attention
// algorithm (arXiv:2310.01889; JAX reference at
// https://github.com/lhao499/RingAttention, Apache-2.0). The Rust plan
// in `baracuda-kernels::attention::ring_attention` orchestrates the
// NCCL-based K/V chunk rotation across ranks; the kernels below
// implement the per-step partial attention (online-softmax fold of the
// resident K/V chunk into the persistent (o_acc, m_acc, l_acc) state)
// plus a finalize kernel that emits the final `y` from the accumulated
// state after all rotation steps complete.
//
// Tier 1 dtype set: f16, bf16. f32 / f64 deferred (Tier 2). Tier 1
// fixes `head_dim = 128` (the launchers reject anything else).
//
// Status codes match the existing family: 0 success / 1 misaligned /
// 2 invalid problem / 3 unsupported / 4 workspace too small /
// 5 internal launch failure (+1000 for raw cudaError_t propagation).
#[cfg(feature = "ring_attention")]
unsafe extern "C" {
    /// Byte size of the persistent Ring Attention accumulator state,
    /// i.e. `o_acc + m_acc + l_acc`, all stored as f32:
    ///
    /// - `o_acc`: `f32[batch, heads, q_local, d]`
    /// - `m_acc`: `f32[batch, heads, q_local]`
    /// - `l_acc`: `f32[batch, heads, q_local]`
    pub fn baracuda_kernels_ring_attention_workspace_bytes(
        batch: i32,
        heads: i32,
        q_local: i32,
        d: i32,
    ) -> usize;

    /// Initialises the accumulator state; independent of operand dtype.
    ///
    /// Writes `o_acc = 0`, `m_acc = -INF`, and `l_acc = 0`. This must run
    /// before the first step kernel.
    ///
    /// NOTE(review): `o_len` / `ml_len` are presumably the element counts
    /// of `o_acc` and of each of `m_acc` / `l_acc` — confirm against the
    /// launcher.
    pub fn baracuda_kernels_ring_attention_init_run(
        o_acc: *mut c_void,
        m_acc: *mut c_void,
        l_acc: *mut c_void,
        o_len: i64,
        ml_len: i64,
        stream: *mut c_void,
    ) -> i32;

    /// `can_implement` companion of
    /// `baracuda_kernels_ring_attention_init_run`: same arguments without
    /// `stream`, and with the accumulator pointers taken as `*const`.
    pub fn baracuda_kernels_ring_attention_init_can_implement(
        o_acc: *const c_void,
        m_acc: *const c_void,
        l_acc: *const c_void,
        o_len: i64,
        ml_len: i64,
    ) -> i32;

    // ---------- f16 ----------

    /// One Ring Attention step for f16 operands.
    ///
    /// Uses online-softmax to fold the contribution of the K/V chunk
    /// currently resident on this rank into the persistent
    /// `(o_acc, m_acc, l_acc)` state. Invoke it once for every rotation
    /// step.
    ///
    /// NOTE(review): `q_global_base` / `k_global_base` are presumably the
    /// global sequence offsets of the local Q rows and of the resident
    /// K chunk (relevant when `is_causal` is set) — confirm.
    pub fn baracuda_kernels_ring_attention_f16_step_run(
        batch: i32,
        heads: i32,
        q_local: i32,
        k_chunk: i32,
        d: i32,
        q_global_base: i32,
        k_global_base: i32,
        scale: f32,
        is_causal: i32,
        q: *const c_void,
        k_local: *const c_void,
        v_local: *const c_void,
        o_acc: *mut c_void,
        m_acc: *mut c_void,
        l_acc: *mut c_void,
        stream: *mut c_void,
    ) -> i32;

    /// Checks, before launch, whether the f16 step kernel can handle the
    /// given problem dimensions.
    pub fn baracuda_kernels_ring_attention_f16_step_can_implement(
        batch: i32,
        heads: i32,
        q_local: i32,
        k_chunk: i32,
        d: i32,
    ) -> i32;

    /// Ring Attention finalize for f16 operands.
    ///
    /// Divides the persistent `o_acc` by `l_acc` and stores the final `y`
    /// in the operand dtype. When `lse` is non-null it also emits
    /// `lse = m + log(l)`; pass `lse = null` to skip that output. Invoke
    /// it exactly once, after the last step kernel.
    pub fn baracuda_kernels_ring_attention_f16_finalize_run(
        batch: i32,
        heads: i32,
        q_local: i32,
        d: i32,
        o_acc: *const c_void,
        m_acc: *const c_void,
        l_acc: *const c_void,
        y: *mut c_void,
        lse: *mut c_void,
        stream: *mut c_void,
    ) -> i32;

    /// Checks, before launch, whether the f16 finalize kernel can handle
    /// the given problem dimensions.
    pub fn baracuda_kernels_ring_attention_f16_finalize_can_implement(
        batch: i32,
        heads: i32,
        q_local: i32,
        d: i32,
    ) -> i32;

    // ---------- bf16 ----------

    /// One Ring Attention step for bf16 operands. The algorithm is the
    /// one described on `..._f16_step_run`; accumulators are f32
    /// throughout.
    pub fn baracuda_kernels_ring_attention_bf16_step_run(
        batch: i32,
        heads: i32,
        q_local: i32,
        k_chunk: i32,
        d: i32,
        q_global_base: i32,
        k_global_base: i32,
        scale: f32,
        is_causal: i32,
        q: *const c_void,
        k_local: *const c_void,
        v_local: *const c_void,
        o_acc: *mut c_void,
        m_acc: *mut c_void,
        l_acc: *mut c_void,
        stream: *mut c_void,
    ) -> i32;

    /// Checks, before launch, whether the bf16 step kernel can handle the
    /// given problem dimensions.
    pub fn baracuda_kernels_ring_attention_bf16_step_can_implement(
        batch: i32,
        heads: i32,
        q_local: i32,
        k_chunk: i32,
        d: i32,
    ) -> i32;

    /// Ring Attention finalize for bf16 operands. Behaves as described on
    /// `..._f16_finalize_run`.
    pub fn baracuda_kernels_ring_attention_bf16_finalize_run(
        batch: i32,
        heads: i32,
        q_local: i32,
        d: i32,
        o_acc: *const c_void,
        m_acc: *const c_void,
        l_acc: *const c_void,
        y: *mut c_void,
        lse: *mut c_void,
        stream: *mut c_void,
    ) -> i32;

    /// Checks, before launch, whether the bf16 finalize kernel can handle
    /// the given problem dimensions.
    pub fn baracuda_kernels_ring_attention_bf16_finalize_can_implement(
        batch: i32,
        heads: i32,
        q_local: i32,
        d: i32,
    ) -> i32;
}
// ============================================================================
// Phase 59b — FA2 BW pass + varlen (FW + BW)
// ============================================================================
//
// Closes the Fuel FA2-retirement requirements by adding:
//
// 1. BW symbols for every head_dim FA2 v2.8.3 ships
// ({32, 64, 96, 128, 192, 256}), per dtype (fp16, bf16).
// 2. Varlen FW + BW: packed Q/K/V/O across heterogeneous sequences
// via `cu_seqlens_q` / `cu_seqlens_k` index tensors.
//
// FA2 v2.8.3's BW and varlen do NOT have separate .cu file families —
// they reuse the same per-(headdim, dtype, causal) instantiations as
// FW; varlen is gated by a runtime `params.cu_seqlens_* != nullptr`
// check inside the BW launch template. As a result Phase 59b adds 24
// new BW .cu source files (mirroring the Phase 59a FW set) plus two
// new launcher TUs (`fa2_backward_launcher.cu`, `fa2_varlen_launcher.cu`),
// but NO separate "varlen .cu" family.
//
// **BW workspace contract**: FA2 BW needs TWO f32 scratch buffers
// (caller-supplied, packed back-to-back in `workspace`):
//
// - `dq_accum` — shape [B, seqlen_q_rounded, H, head_size_rounded] f32
// (dense) or [total_q + 128*B, H, head_size_rounded] f32
// (varlen).
// - `dsoftmax_d` — shape [B, H, seqlen_q_rounded] f32 (dense) or
// [H, total_q + 128*B] f32 (varlen).
//
// where `seqlen_q_rounded = round_up(sq, 128)` and
// `head_size_rounded = round_up(d, d <= 128 ? 32 : 64)`.
//
// The companion `..._backward_workspace_size` symbols return the total
// byte size; the launcher zeros the scratch via `cudaMemsetAsync` before
// the launch (FA2's BW kernels read dq_accum before final convert).
//
// **LSE input contract**: BW's `lse` arg must be the **f32** LSE
// written by the FA2 FW pass (`softmax_lse` arg of the FW `..._run_v2`
// symbols). Reusing baracuda's bespoke FlashSdpa LSE (typed T) is
// INVALID — FA2 always stores LSE in f32.
//
// **Varlen layout**:
// - Q / O : packed [total_q, H, D] (row_stride = D * H,
// head_stride = D, batch_stride = 0)
// - K / V : packed [total_k, H_k, D] (row_stride = D * H_k)
// - cu_seqlens_q : i32[batch + 1] — cumulative
// (cu_seqlens_q[0] = 0, cu_seqlens_q[batch] = total_q)
// - cu_seqlens_k : i32[batch + 1] — same convention
// - varlen LSE : f32 [H, total_q + 128 * batch] (unpadded format).
#[cfg(feature = "fa2")]
unsafe extern "C" {
/// FA2 backward, f16. Computes dQ, dK, dV given FW-saved O + LSE (f32)
/// and upstream gradient dO. ALiBi / sliding window / softcap plumbed.
/// This is the FA2 backward trailblazer — its contract carries over
/// to the bf16 sibling.
///
/// `workspace_bytes` must be at least
/// `baracuda_kernels_fa2_sdpa_backward_workspace_size(batch, num_heads, seq_q, head_dim)`
/// bytes. The launcher zero-fills the workspace internally.
///
/// # LSE saved-tensor contract (Phase 63)
///
/// `lse` MUST be the exact same f32 buffer written by the
/// corresponding [`baracuda_kernels_fa2_sdpa_f16_run`] (or
/// `..._run_v2`) forward call on the same `(q, k, v)`. Pre-allocate
/// via [`baracuda_kernels_fa2_sdpa_lse_size`]`(batch, num_heads, seq_q)`
/// f32 elements. See the FW trailblazer doc for the full
/// FW→saved-LSE→BW handoff pattern.
///
/// Passing a different LSE buffer (e.g. recomputed, or from a
/// different FW pass) produces silently-wrong gradients — the BW
/// kernel uses LSE as the numerically-stable softmax normalizer
/// and trusts the value.
///
/// # Supported head_dim (Phase 63 — surfaced from `flash_sdpa_backward.rs`)
///
/// FA2 BW supports head_dim ∈ {32, 64, 96, 128, 192, 256}. hd160 /
/// hd224 / hd512 BW are fundamentally not supported by FA2's BW
/// algorithm (kBlockKSmem=32 atom_layout + kBlockM≥64 static_assert
/// constraints; see `crates/baracuda-kernels/src/attention/flash_sdpa_backward.rs`
/// `FA2_BW_SUPPORTED_HEAD_DIMS` and Phase 60 VENDOR.md for the
/// experimental confirmation). Callers needing BW at hd160 / hd224
/// / hd512 must fall back to the bespoke `SdpaBackwardPlan` (Phase
/// 6 / Milestone 6.6) which supports d_k ≤ 128. This BW
/// head_dim limit matches Fuel's Vulkan FA backward cap.
///
/// NOTE(review): the fallback advice above is self-contradictory as
/// written — a plan capped at `d_k ≤ 128` cannot serve hd160 / hd224 /
/// hd512. Confirm the real limit of `SdpaBackwardPlan` (or the intended
/// fallback) and correct this paragraph.
///
/// # Returns
///
/// An `i32` status code (`0` on success).
pub fn baracuda_kernels_fa2_sdpa_backward_f16_run(
batch: i32,
num_heads: i32,
num_heads_k: i32,
seq_q: i32,
seq_k: i32,
head_dim: i32,
softmax_scale: f32,
is_causal: i32,
alibi_slopes_ptr: *const c_void,
alibi_batch_stride: i32,
window_size_left: i32,
window_size_right: i32,
softcap: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
o: *const c_void,
dout: *const c_void,
lse: *const c_void,
dq: *mut c_void,
dk: *mut c_void,
dv: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// FA2 backward, bf16. See `..._backward_f16_run`: the parameter list
/// is identical and the f16 doc states its contract (workspace sizing,
/// f32 LSE hand-off, supported head_dim set) carries over to this
/// sibling.
pub fn baracuda_kernels_fa2_sdpa_backward_bf16_run(
batch: i32,
num_heads: i32,
num_heads_k: i32,
seq_q: i32,
seq_k: i32,
head_dim: i32,
softmax_scale: f32,
is_causal: i32,
alibi_slopes_ptr: *const c_void,
alibi_batch_stride: i32,
window_size_left: i32,
window_size_right: i32,
softcap: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
o: *const c_void,
dout: *const c_void,
lse: *const c_void,
dq: *mut c_void,
dk: *mut c_void,
dv: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// FA2 BW host-side can-implement, f16.
///
/// Takes only the problem description from `..._backward_f16_run`
/// (dimensions, `is_causal`, window sizes, `softcap`); no device
/// pointers, `softmax_scale`, ALiBi arguments, workspace, or stream.
/// Returns an `i32` status code (`0` when implementable).
pub fn baracuda_kernels_fa2_sdpa_backward_f16_can_implement(
batch: i32,
num_heads: i32,
num_heads_k: i32,
seq_q: i32,
seq_k: i32,
head_dim: i32,
is_causal: i32,
window_size_left: i32,
window_size_right: i32,
softcap: f32,
) -> i32;
/// FA2 BW host-side can-implement, bf16. Same argument list as the
/// f16 can-implement.
pub fn baracuda_kernels_fa2_sdpa_backward_bf16_can_implement(
batch: i32,
num_heads: i32,
num_heads_k: i32,
seq_q: i32,
seq_k: i32,
head_dim: i32,
is_causal: i32,
window_size_left: i32,
window_size_right: i32,
softcap: f32,
) -> i32;
/// Required BW workspace size in bytes (dense path). Caller passes
/// this much memory in the `workspace` arg of `..._backward_<dt>_run`.
///
/// The signature carries no dtype, so the same query serves both the
/// f16 and bf16 `_run` symbols (the scratch buffers are f32 per the
/// section header).
pub fn baracuda_kernels_fa2_sdpa_backward_workspace_size(
batch: i32,
num_heads: i32,
seq_q: i32,
head_dim: i32,
) -> usize;
/// FA2 varlen forward, f16. Packed Q/K/V/O across `batch` sequences.
/// Writes `out` (packed [total_q, H, D]) and `softmax_lse` (f32
/// [H, total_q + 128 * batch]).
///
/// `cu_seqlens_q` / `cu_seqlens_k` are the `i32[batch + 1]` cumulative
/// sequence-length index tensors described in the section header
/// (first entry 0, last entry `total_q` / `total_k`). Size the
/// `softmax_lse` buffer with
/// `baracuda_kernels_fa2_sdpa_varlen_lse_size`.
///
/// NOTE(review): no sizing query for the varlen FW `workspace` is
/// declared in this block — confirm what `workspace_bytes` must be.
pub fn baracuda_kernels_fa2_sdpa_varlen_f16_run(
batch: i32,
num_heads: i32,
num_heads_k: i32,
max_seqlen_q: i32,
max_seqlen_k: i32,
total_q: i32,
total_k: i32,
head_dim: i32,
softmax_scale: f32,
is_causal: i32,
alibi_slopes_ptr: *const c_void,
alibi_batch_stride: i32,
window_size_left: i32,
window_size_right: i32,
softcap: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
cu_seqlens_q: *const i32,
cu_seqlens_k: *const i32,
out: *mut c_void,
softmax_lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// FA2 varlen forward, bf16. Same argument list as
/// `..._varlen_f16_run`.
pub fn baracuda_kernels_fa2_sdpa_varlen_bf16_run(
batch: i32,
num_heads: i32,
num_heads_k: i32,
max_seqlen_q: i32,
max_seqlen_k: i32,
total_q: i32,
total_k: i32,
head_dim: i32,
softmax_scale: f32,
is_causal: i32,
alibi_slopes_ptr: *const c_void,
alibi_batch_stride: i32,
window_size_left: i32,
window_size_right: i32,
softcap: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
cu_seqlens_q: *const i32,
cu_seqlens_k: *const i32,
out: *mut c_void,
softmax_lse: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// FA2 varlen FW can-implement, f16.
///
/// Takes only the problem description from `..._varlen_f16_run`
/// (dimensions, `is_causal`, window sizes, `softcap`); no device
/// pointers, `softmax_scale`, ALiBi arguments, workspace, or stream.
pub fn baracuda_kernels_fa2_sdpa_varlen_f16_can_implement(
batch: i32,
num_heads: i32,
num_heads_k: i32,
max_seqlen_q: i32,
max_seqlen_k: i32,
total_q: i32,
total_k: i32,
head_dim: i32,
is_causal: i32,
window_size_left: i32,
window_size_right: i32,
softcap: f32,
) -> i32;
/// FA2 varlen FW can-implement, bf16. Same argument list as the f16
/// can-implement.
pub fn baracuda_kernels_fa2_sdpa_varlen_bf16_can_implement(
batch: i32,
num_heads: i32,
num_heads_k: i32,
max_seqlen_q: i32,
max_seqlen_k: i32,
total_q: i32,
total_k: i32,
head_dim: i32,
is_causal: i32,
window_size_left: i32,
window_size_right: i32,
softcap: f32,
) -> i32;
/// Varlen LSE size in **f32 elements**: `num_heads * (total_q + 128 * batch)`.
/// Caller multiplies by 4 for bytes.
pub fn baracuda_kernels_fa2_sdpa_varlen_lse_size(
batch: i32,
num_heads: i32,
total_q: i32,
) -> usize;
/// FA2 varlen backward, f16. Same packed layout as varlen FW.
/// Workspace size: `..._varlen_backward_workspace_size(...)`.
///
/// Inputs are `q`, `k`, `v`, the forward output `o`, the upstream
/// gradient `dout`, and the f32 `lse` saved by the FA2 forward pass;
/// outputs are `dq`, `dk`, `dv`. `cu_seqlens_q` / `cu_seqlens_k` follow
/// the same `i32[batch + 1]` cumulative convention as varlen FW.
pub fn baracuda_kernels_fa2_sdpa_varlen_backward_f16_run(
batch: i32,
num_heads: i32,
num_heads_k: i32,
max_seqlen_q: i32,
max_seqlen_k: i32,
total_q: i32,
total_k: i32,
head_dim: i32,
softmax_scale: f32,
is_causal: i32,
alibi_slopes_ptr: *const c_void,
alibi_batch_stride: i32,
window_size_left: i32,
window_size_right: i32,
softcap: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
o: *const c_void,
dout: *const c_void,
lse: *const c_void,
cu_seqlens_q: *const i32,
cu_seqlens_k: *const i32,
dq: *mut c_void,
dk: *mut c_void,
dv: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// FA2 varlen backward, bf16. Same argument list as
/// `..._varlen_backward_f16_run`.
pub fn baracuda_kernels_fa2_sdpa_varlen_backward_bf16_run(
batch: i32,
num_heads: i32,
num_heads_k: i32,
max_seqlen_q: i32,
max_seqlen_k: i32,
total_q: i32,
total_k: i32,
head_dim: i32,
softmax_scale: f32,
is_causal: i32,
alibi_slopes_ptr: *const c_void,
alibi_batch_stride: i32,
window_size_left: i32,
window_size_right: i32,
softcap: f32,
q: *const c_void,
k: *const c_void,
v: *const c_void,
o: *const c_void,
dout: *const c_void,
lse: *const c_void,
cu_seqlens_q: *const i32,
cu_seqlens_k: *const i32,
dq: *mut c_void,
dk: *mut c_void,
dv: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// FA2 varlen BW can-implement, f16.
///
/// Takes only the problem description from
/// `..._varlen_backward_f16_run` (dimensions, `is_causal`, window
/// sizes, `softcap`); no device pointers, workspace, or stream.
pub fn baracuda_kernels_fa2_sdpa_varlen_backward_f16_can_implement(
batch: i32,
num_heads: i32,
num_heads_k: i32,
max_seqlen_q: i32,
max_seqlen_k: i32,
total_q: i32,
total_k: i32,
head_dim: i32,
is_causal: i32,
window_size_left: i32,
window_size_right: i32,
softcap: f32,
) -> i32;
/// FA2 varlen BW can-implement, bf16. Same argument list as the f16
/// can-implement.
pub fn baracuda_kernels_fa2_sdpa_varlen_backward_bf16_can_implement(
batch: i32,
num_heads: i32,
num_heads_k: i32,
max_seqlen_q: i32,
max_seqlen_k: i32,
total_q: i32,
total_k: i32,
head_dim: i32,
is_causal: i32,
window_size_left: i32,
window_size_right: i32,
softcap: f32,
) -> i32;
/// Required varlen BW workspace size in bytes. Same layout as dense
/// BW workspace (dq_accum + dsoftmax_sum back-to-back), sized off
/// `total_q + 128 * batch` rows.
///
/// NOTE(review): `dsoftmax_sum` here and `dsoftmax_d` in the Phase 59b
/// section header appear to name the same second scratch buffer —
/// pick one spelling.
pub fn baracuda_kernels_fa2_sdpa_varlen_backward_workspace_size(
batch: i32,
num_heads: i32,
max_seqlen_q: i32,
total_q: i32,
head_dim: i32,
) -> usize;
}
// =============================================================================
// Phase 72 strided-sibling FFI exports for the normalizer + shape ops.
//
// Each `_strided_run` / `_strided_can_implement` symbol has the same signature
// as its non-strided sibling and routes to the same underlying CUDA launcher
// — the existing `_run` already accepts stride arrays and the C kernel
// honors them. The strided sibling exists so callers building explicit
// dispatch tables can pick the strided path by name (matches the
// Phase 14/18 convention for binary / unary-param ops).
//
// Covers: rms_norm (FW + BW), layer_norm (FW + BW), softmax (FW + BW),
// log_softmax (FW + BW), flip, roll, permute. 4 dtypes (f32/f16/bf16/f64)
// per family. 88 new FFI symbols total (44 `_strided_run` + 44
// `_strided_can_implement`).
// =============================================================================
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---- RMSNorm FW strided siblings ----
/// RMSNorm FW strided sibling, f32. Same contract as
/// `baracuda_kernels_rms_norm_f32_run`; identical underlying launcher.
///
/// Arguments, in order: the scalar `eps`; the tensor description
/// (`numel`, `rank`, `shape`, plus the stride arrays `stride_x`,
/// `stride_y`, `stride_rms`); the normalisation description
/// (`norm_axes_mask`, `norm_total_extent`); the read-only operands `x`
/// and `gamma`; the outputs `y` and `rms_out`; then the
/// `workspace` / `workspace_bytes` / `stream` trailer.
///
/// NOTE(review): `shape` and each stride array presumably hold `rank`
/// entries, `stride_rms` presumably describes `rms_out`, and
/// `norm_axes_mask` presumably carries one bit per normalised axis —
/// confirm against the non-strided sibling.
///
/// Returns an `i32` status code (`0` on success).
pub fn baracuda_kernels_rms_norm_f32_strided_run(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_rms: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
y: *mut c_void,
rms_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `rms_norm_f32_strided_can_implement` companion.
///
/// Takes the `_run` argument list without the `workspace`,
/// `workspace_bytes`, and `stream` trailer; `y` and `rms_out` are
/// `*const` here.
pub fn baracuda_kernels_rms_norm_f32_strided_can_implement(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_rms: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
y: *const c_void,
rms_out: *const c_void,
) -> i32;
/// RMSNorm FW strided sibling, f16. See `rms_norm_f32_strided_run`;
/// the argument list is identical.
pub fn baracuda_kernels_rms_norm_f16_strided_run(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_rms: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
y: *mut c_void,
rms_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `rms_norm_f16_strided_can_implement` companion. Same argument list
/// as the f32 can-implement.
pub fn baracuda_kernels_rms_norm_f16_strided_can_implement(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_rms: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
y: *const c_void,
rms_out: *const c_void,
) -> i32;
/// RMSNorm FW strided sibling, bf16. See `rms_norm_f32_strided_run`;
/// the argument list is identical.
pub fn baracuda_kernels_rms_norm_bf16_strided_run(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_rms: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
y: *mut c_void,
rms_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `rms_norm_bf16_strided_can_implement` companion. Same argument list
/// as the f32 can-implement.
pub fn baracuda_kernels_rms_norm_bf16_strided_can_implement(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_rms: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
y: *const c_void,
rms_out: *const c_void,
) -> i32;
/// RMSNorm FW strided sibling, f64. See `rms_norm_f32_strided_run`;
/// the argument list is identical. Note that `eps` is declared `f32`
/// for the f64 variant as well.
pub fn baracuda_kernels_rms_norm_f64_strided_run(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_rms: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
y: *mut c_void,
rms_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `rms_norm_f64_strided_can_implement` companion. Same argument list
/// as the f32 can-implement (including the `f32` `eps`).
pub fn baracuda_kernels_rms_norm_f64_strided_can_implement(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_rms: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
y: *const c_void,
rms_out: *const c_void,
) -> i32;
}
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---- RMSNorm BW strided siblings ----
/// RMSNorm BW strided sibling, f32. Same contract as
/// `baracuda_kernels_rms_norm_backward_f32_run`; identical underlying launcher.
///
/// Arguments, in order: the tensor description (`numel`, `rank`,
/// `shape`, plus the stride arrays `stride_dy`, `stride_x`,
/// `stride_rms`, `stride_dx`); the normalisation description
/// (`norm_axes_mask`, `norm_total_extent`); the read-only operands
/// `dy`, `x`, `gamma`, `rms`; the outputs `dx` and `dgamma`; then the
/// `workspace` / `workspace_bytes` / `stream` trailer. Unlike the FW
/// symbols there is no `eps` argument.
///
/// NOTE(review): `rms` is presumably the `rms_out` tensor saved by the
/// matching forward call, and `shape` / the stride arrays presumably
/// hold `rank` entries each — confirm against the non-strided sibling.
///
/// Returns an `i32` status code (`0` on success).
pub fn baracuda_kernels_rms_norm_backward_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_rms: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
rms: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `rms_norm_backward_f32_strided_can_implement` companion.
///
/// Takes the `_run` argument list without the `workspace`,
/// `workspace_bytes`, and `stream` trailer; `dx` and `dgamma` are
/// `*const` here.
pub fn baracuda_kernels_rms_norm_backward_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_rms: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
rms: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
) -> i32;
/// RMSNorm BW strided sibling, f16. Same argument list as
/// `rms_norm_backward_f32_strided_run`.
pub fn baracuda_kernels_rms_norm_backward_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_rms: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
rms: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `rms_norm_backward_f16_strided_can_implement` companion. Same
/// argument list as the f32 can-implement.
pub fn baracuda_kernels_rms_norm_backward_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_rms: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
rms: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
) -> i32;
/// RMSNorm BW strided sibling, bf16. Same argument list as
/// `rms_norm_backward_f32_strided_run`.
pub fn baracuda_kernels_rms_norm_backward_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_rms: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
rms: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `rms_norm_backward_bf16_strided_can_implement` companion. Same
/// argument list as the f32 can-implement.
pub fn baracuda_kernels_rms_norm_backward_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_rms: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
rms: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
) -> i32;
/// RMSNorm BW strided sibling, f64. Same argument list as
/// `rms_norm_backward_f32_strided_run`.
pub fn baracuda_kernels_rms_norm_backward_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_rms: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
rms: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `rms_norm_backward_f64_strided_can_implement` companion. Same
/// argument list as the f32 can-implement.
pub fn baracuda_kernels_rms_norm_backward_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_rms: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
rms: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
) -> i32;
}
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---- LayerNorm FW strided siblings ----
/// LayerNorm FW strided sibling, f32. Same contract as
/// `baracuda_kernels_layer_norm_f32_run`; identical underlying launcher.
///
/// Arguments, in order: the scalar `eps`; the tensor description
/// (`numel`, `rank`, `shape`, plus the stride arrays `stride_x`,
/// `stride_y`, `stride_save`); the normalisation description
/// (`norm_axes_mask`, `norm_total_extent`); the read-only operands
/// `x`, `gamma`, `beta`; the outputs `y`, `mean_out`, `inv_std_out`;
/// then the `workspace` / `workspace_bytes` / `stream` trailer.
///
/// NOTE(review): `stride_save` presumably describes both saved
/// statistics (`mean_out` and `inv_std_out`), and `shape` / the stride
/// arrays presumably hold `rank` entries each — confirm against the
/// non-strided sibling.
///
/// Returns an `i32` status code (`0` on success).
pub fn baracuda_kernels_layer_norm_f32_strided_run(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_save: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *mut c_void,
mean_out: *mut c_void,
inv_std_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `layer_norm_f32_strided_can_implement` companion.
///
/// Takes the `_run` argument list without the `workspace`,
/// `workspace_bytes`, and `stream` trailer; `y`, `mean_out`, and
/// `inv_std_out` are `*const` here.
pub fn baracuda_kernels_layer_norm_f32_strided_can_implement(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_save: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *const c_void,
mean_out: *const c_void,
inv_std_out: *const c_void,
) -> i32;
/// LayerNorm FW strided sibling, f16. Same argument list as
/// `layer_norm_f32_strided_run`.
pub fn baracuda_kernels_layer_norm_f16_strided_run(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_save: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *mut c_void,
mean_out: *mut c_void,
inv_std_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `layer_norm_f16_strided_can_implement` companion. Same argument
/// list as the f32 can-implement.
pub fn baracuda_kernels_layer_norm_f16_strided_can_implement(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_save: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *const c_void,
mean_out: *const c_void,
inv_std_out: *const c_void,
) -> i32;
/// LayerNorm FW strided sibling, bf16. Same argument list as
/// `layer_norm_f32_strided_run`.
pub fn baracuda_kernels_layer_norm_bf16_strided_run(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_save: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *mut c_void,
mean_out: *mut c_void,
inv_std_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `layer_norm_bf16_strided_can_implement` companion. Same argument
/// list as the f32 can-implement.
pub fn baracuda_kernels_layer_norm_bf16_strided_can_implement(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_save: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *const c_void,
mean_out: *const c_void,
inv_std_out: *const c_void,
) -> i32;
/// LayerNorm FW strided sibling, f64. Same argument list as
/// `layer_norm_f32_strided_run`. Note that `eps` is declared `f32`
/// for the f64 variant as well.
pub fn baracuda_kernels_layer_norm_f64_strided_run(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_save: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *mut c_void,
mean_out: *mut c_void,
inv_std_out: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `layer_norm_f64_strided_can_implement` companion. Same argument
/// list as the f32 can-implement (including the `f32` `eps`).
pub fn baracuda_kernels_layer_norm_f64_strided_can_implement(
eps: f32,
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
stride_save: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
x: *const c_void,
gamma: *const c_void,
beta: *const c_void,
y: *const c_void,
mean_out: *const c_void,
inv_std_out: *const c_void,
) -> i32;
}
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---- LayerNorm BW strided siblings ----
/// LayerNorm BW strided sibling, f32. Same contract as
/// `baracuda_kernels_layer_norm_backward_f32_run`; identical launcher.
pub fn baracuda_kernels_layer_norm_backward_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_save: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
mean_in: *const c_void,
inv_std_in: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
dbeta: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `layer_norm_backward_f32_strided_can_implement` companion.
pub fn baracuda_kernels_layer_norm_backward_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_save: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
mean_in: *const c_void,
inv_std_in: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
dbeta: *const c_void,
) -> i32;
/// LayerNorm BW strided sibling, f16.
pub fn baracuda_kernels_layer_norm_backward_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_save: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
mean_in: *const c_void,
inv_std_in: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
dbeta: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `layer_norm_backward_f16_strided_can_implement` companion.
pub fn baracuda_kernels_layer_norm_backward_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_save: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
mean_in: *const c_void,
inv_std_in: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
dbeta: *const c_void,
) -> i32;
/// LayerNorm BW strided sibling, bf16.
pub fn baracuda_kernels_layer_norm_backward_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_save: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
mean_in: *const c_void,
inv_std_in: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
dbeta: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `layer_norm_backward_bf16_strided_can_implement` companion.
pub fn baracuda_kernels_layer_norm_backward_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_save: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
mean_in: *const c_void,
inv_std_in: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
dbeta: *const c_void,
) -> i32;
/// LayerNorm BW strided sibling, f64.
///
/// Returns an `i32` status code as listed in the crate-level "Status codes"
/// section (`0` is success).
///
/// Argument groups, in declaration order:
/// - `numel`, `rank`, `shape`: problem geometry. NOTE(review): `shape` and
/// every `stride_*` array presumably hold `rank` entries — confirm against
/// the C prototype.
/// - `stride_dy`, `stride_x`, `stride_save`, `stride_dx`: one stride array
/// per tensor. NOTE(review): `stride_save` presumably describes `mean_in`
/// and `inv_std_in` — confirm.
/// - `norm_axes_mask`, `norm_total_extent`: description of the normalized
/// axes. NOTE(review): presumably a per-axis bitmask and the product of the
/// normalized extents — confirm.
/// - `dy`, `x`, `gamma`, `mean_in`, `inv_std_in`: operands declared `*const`.
/// - `dx`, `dgamma`, `dbeta`: operands declared `*mut`. No stride array is
/// passed for `dgamma` / `dbeta`.
/// - `workspace`, `workspace_bytes`, `stream`: scratch buffer, its size in
/// bytes, and the `cudaStream_t` passed as `*mut c_void`.
///
/// # Safety
///
/// The crate-level contract applies: pointer arguments are dereferenced
/// without bounds checks, a non-null `workspace` must cover at least
/// `workspace_bytes` of writable device memory, and `stream` must be a valid
/// CUDA stream.
pub fn baracuda_kernels_layer_norm_backward_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_save: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
mean_in: *const c_void,
inv_std_in: *const c_void,
dx: *mut c_void,
dgamma: *mut c_void,
dbeta: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `layer_norm_backward_f64_strided_can_implement` companion.
///
/// Takes the same problem description and operand pointers as
/// `baracuda_kernels_layer_norm_backward_f64_strided_run`, but no workspace
/// or stream, and declares `dx`, `dgamma`, and `dbeta` as `*const`.
/// Returns an `i32` status code as listed in the crate-level "Status codes"
/// section (`0` is success).
///
/// NOTE(review): presumably this only validates the arguments and launches
/// nothing — confirm against the C implementation.
///
/// # Safety
///
/// The crate-level contract for pointer arguments applies.
pub fn baracuda_kernels_layer_norm_backward_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_x: *const i64,
stride_save: *const i64,
stride_dx: *const i64,
norm_axes_mask: i32,
norm_total_extent: i32,
dy: *const c_void,
x: *const c_void,
gamma: *const c_void,
mean_in: *const c_void,
inv_std_in: *const c_void,
dx: *const c_void,
dgamma: *const c_void,
dbeta: *const c_void,
) -> i32;
}
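// Softmax and LogSoftmax share signatures: the two FW ops use one parameter
// list and the two BW ops use another (with `dy` / `y` / `dx` operands).
// Each macro emits 4 dtype rows (f32, f16, bf16, f64) below.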
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---- Softmax FW strided siblings ----
//
// The four `*_run` / `*_can_implement` pairs in this block have identical
// parameter lists and differ only in the dtype suffix of the symbol
// (f32, f16, bf16, f64). Every function returns an `i32` status code as
// listed in the crate-level "Status codes" section (`0` is success).
//
// Parameter groups, in declaration order:
//   numel, rank, shape            problem geometry
//   stride_x, stride_y            one stride array per tensor
//   softmax_axis, softmax_extent,
//   softmax_stride_x, softmax_stride_y
//                                 description of the softmax axis
//   x (`*const`), y               operand buffers
//   workspace, workspace_bytes,
//   stream                        `*_run` only; `stream` is a `cudaStream_t`
//                                 passed as `*mut c_void`
//
// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries, and
// `softmax_extent` / `softmax_stride_*` presumably repeat the extent and
// strides found at `softmax_axis` — confirm against the C prototype.
//
// The `*_can_implement` companions take no workspace or stream and declare
// `y` as `*const`. NOTE(review): presumably they only validate the
// arguments — confirm.
//
// Safety: the crate-level contract applies to every call (pointers are
// dereferenced unchecked; a non-null `workspace` must cover at least
// `workspace_bytes`; `stream` must be a valid CUDA stream).
/// Softmax FW strided sibling, f32. Same contract as
/// `baracuda_kernels_softmax_f32_run`; identical underlying launcher.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_softmax_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `softmax_f32_strided_can_implement` companion.
///
/// Same problem description as `baracuda_kernels_softmax_f32_strided_run`,
/// without workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_softmax_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Softmax FW strided sibling, f16.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_softmax_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `softmax_f16_strided_can_implement` companion.
///
/// Same problem description as `baracuda_kernels_softmax_f16_strided_run`,
/// without workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_softmax_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Softmax FW strided sibling, bf16.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_softmax_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `softmax_bf16_strided_can_implement` companion.
///
/// Same problem description as `baracuda_kernels_softmax_bf16_strided_run`,
/// without workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_softmax_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Softmax FW strided sibling, f64.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_softmax_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `softmax_f64_strided_can_implement` companion.
///
/// Same problem description as `baracuda_kernels_softmax_f64_strided_run`,
/// without workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_softmax_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---- Softmax BW strided siblings ----
//
// The four `*_run` / `*_can_implement` pairs in this block have identical
// parameter lists and differ only in the dtype suffix of the symbol
// (f32, f16, bf16, f64). Every function returns an `i32` status code as
// listed in the crate-level "Status codes" section (`0` is success).
//
// Parameter groups, in declaration order:
//   numel, rank, shape            problem geometry
//   stride_dy, stride_y,
//   stride_dx                     one stride array per tensor
//   softmax_axis, softmax_extent,
//   softmax_stride_dy, softmax_stride_y
//                                 description of the softmax axis
//   dy, y (`*const`), dx          operand buffers
//   workspace, workspace_bytes,
//   stream                        `*_run` only; `stream` is a `cudaStream_t`
//                                 passed as `*mut c_void`
//
// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries —
// confirm against the C prototype.
// NOTE(review): a softmax-axis stride is passed for `dy` and `y` but not
// for `dx`; confirm this matches the C prototype.
//
// The `*_can_implement` companions take no workspace or stream and declare
// `dx` as `*const`. NOTE(review): presumably they only validate the
// arguments — confirm.
//
// Safety: the crate-level contract applies to every call (pointers are
// dereferenced unchecked; a non-null `workspace` must cover at least
// `workspace_bytes`; `stream` must be a valid CUDA stream).
/// Softmax BW strided sibling, f32.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_softmax_backward_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `softmax_backward_f32_strided_can_implement` companion.
///
/// Same problem description as
/// `baracuda_kernels_softmax_backward_f32_strided_run`, without workspace or
/// stream; returns an `i32` status code.
pub fn baracuda_kernels_softmax_backward_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Softmax BW strided sibling, f16.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_softmax_backward_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `softmax_backward_f16_strided_can_implement` companion.
///
/// Same problem description as
/// `baracuda_kernels_softmax_backward_f16_strided_run`, without workspace or
/// stream; returns an `i32` status code.
pub fn baracuda_kernels_softmax_backward_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Softmax BW strided sibling, bf16.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_softmax_backward_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `softmax_backward_bf16_strided_can_implement` companion.
///
/// Same problem description as
/// `baracuda_kernels_softmax_backward_bf16_strided_run`, without workspace or
/// stream; returns an `i32` status code.
pub fn baracuda_kernels_softmax_backward_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// Softmax BW strided sibling, f64.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_softmax_backward_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `softmax_backward_f64_strided_can_implement` companion.
///
/// Same problem description as
/// `baracuda_kernels_softmax_backward_f64_strided_run`, without workspace or
/// stream; returns an `i32` status code.
pub fn baracuda_kernels_softmax_backward_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
}
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---- LogSoftmax FW strided siblings ----
//
// The four `*_run` / `*_can_implement` pairs in this block have identical
// parameter lists and differ only in the dtype suffix of the symbol
// (f32, f16, bf16, f64). Every function returns an `i32` status code as
// listed in the crate-level "Status codes" section (`0` is success).
//
// Parameter groups, in declaration order:
//   numel, rank, shape            problem geometry
//   stride_x, stride_y            one stride array per tensor
//   softmax_axis, softmax_extent,
//   softmax_stride_x, softmax_stride_y
//                                 description of the softmax axis
//   x (`*const`), y               operand buffers
//   workspace, workspace_bytes,
//   stream                        `*_run` only; `stream` is a `cudaStream_t`
//                                 passed as `*mut c_void`
//
// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries, and
// `softmax_extent` / `softmax_stride_*` presumably repeat the extent and
// strides found at `softmax_axis` — confirm against the C prototype.
//
// The `*_can_implement` companions take no workspace or stream and declare
// `y` as `*const`. NOTE(review): presumably they only validate the
// arguments — confirm.
//
// Safety: the crate-level contract applies to every call (pointers are
// dereferenced unchecked; a non-null `workspace` must cover at least
// `workspace_bytes`; `stream` must be a valid CUDA stream).
/// LogSoftmax FW strided sibling, f32. ABI identical to softmax FW.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_log_softmax_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `log_softmax_f32_strided_can_implement` companion.
///
/// Same problem description as
/// `baracuda_kernels_log_softmax_f32_strided_run`, without workspace or
/// stream; returns an `i32` status code.
pub fn baracuda_kernels_log_softmax_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// LogSoftmax FW strided sibling, f16.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_log_softmax_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `log_softmax_f16_strided_can_implement` companion.
///
/// Same problem description as
/// `baracuda_kernels_log_softmax_f16_strided_run`, without workspace or
/// stream; returns an `i32` status code.
pub fn baracuda_kernels_log_softmax_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// LogSoftmax FW strided sibling, bf16.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_log_softmax_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `log_softmax_bf16_strided_can_implement` companion.
///
/// Same problem description as
/// `baracuda_kernels_log_softmax_bf16_strided_run`, without workspace or
/// stream; returns an `i32` status code.
pub fn baracuda_kernels_log_softmax_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// LogSoftmax FW strided sibling, f64.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_log_softmax_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `log_softmax_f64_strided_can_implement` companion.
///
/// Same problem description as
/// `baracuda_kernels_log_softmax_f64_strided_run`, without workspace or
/// stream; returns an `i32` status code.
pub fn baracuda_kernels_log_softmax_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_x: *const i64,
stride_y: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_x: i64,
softmax_stride_y: i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---- LogSoftmax BW strided siblings ----
//
// The four `*_run` / `*_can_implement` pairs in this block have identical
// parameter lists and differ only in the dtype suffix of the symbol
// (f32, f16, bf16, f64). Every function returns an `i32` status code as
// listed in the crate-level "Status codes" section (`0` is success).
//
// Parameter groups, in declaration order:
//   numel, rank, shape            problem geometry
//   stride_dy, stride_y,
//   stride_dx                     one stride array per tensor
//   softmax_axis, softmax_extent,
//   softmax_stride_dy, softmax_stride_y
//                                 description of the softmax axis
//   dy, y (`*const`), dx          operand buffers
//   workspace, workspace_bytes,
//   stream                        `*_run` only; `stream` is a `cudaStream_t`
//                                 passed as `*mut c_void`
//
// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries —
// confirm against the C prototype.
// NOTE(review): a softmax-axis stride is passed for `dy` and `y` but not
// for `dx`; confirm this matches the C prototype.
//
// The `*_can_implement` companions take no workspace or stream and declare
// `dx` as `*const`. NOTE(review): presumably they only validate the
// arguments — confirm.
//
// Safety: the crate-level contract applies to every call (pointers are
// dereferenced unchecked; a non-null `workspace` must cover at least
// `workspace_bytes`; `stream` must be a valid CUDA stream).
/// LogSoftmax BW strided sibling, f32. ABI identical to softmax BW.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_log_softmax_backward_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `log_softmax_backward_f32_strided_can_implement` companion.
///
/// Same problem description as
/// `baracuda_kernels_log_softmax_backward_f32_strided_run`, without
/// workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_log_softmax_backward_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// LogSoftmax BW strided sibling, f16.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_log_softmax_backward_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `log_softmax_backward_f16_strided_can_implement` companion.
///
/// Same problem description as
/// `baracuda_kernels_log_softmax_backward_f16_strided_run`, without
/// workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_log_softmax_backward_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// LogSoftmax BW strided sibling, bf16.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_log_softmax_backward_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `log_softmax_backward_bf16_strided_can_implement` companion.
///
/// Same problem description as
/// `baracuda_kernels_log_softmax_backward_bf16_strided_run`, without
/// workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_log_softmax_backward_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
/// LogSoftmax BW strided sibling, f64.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_log_softmax_backward_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `log_softmax_backward_f64_strided_can_implement` companion.
///
/// Same problem description as
/// `baracuda_kernels_log_softmax_backward_f64_strided_run`, without
/// workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_log_softmax_backward_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
stride_dy: *const i64,
stride_y: *const i64,
stride_dx: *const i64,
softmax_axis: i32,
softmax_extent: i32,
softmax_stride_dy: i64,
softmax_stride_y: i64,
dy: *const c_void,
y: *const c_void,
dx: *const c_void,
) -> i32;
}
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---- Flip strided siblings ----
//
// The four `*_run` / `*_can_implement` pairs in this block have identical
// parameter lists and differ only in the dtype suffix of the symbol
// (f32, f16, bf16, f64). Every function returns an `i32` status code as
// listed in the crate-level "Status codes" section (`0` is success).
//
// Parameter groups, in declaration order:
//   numel, rank, shape            problem geometry
//   flip_axes                     `i32` array selecting the flipped axes
//   stride_x, stride_y            one stride array per tensor
//   x (`*const`), y               operand buffers
//   workspace, workspace_bytes,
//   stream                        `*_run` only; `stream` is a `cudaStream_t`
//                                 passed as `*mut c_void`
//
// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries. The
// length and encoding of `flip_axes` (per-axis flags vs. a list of axis
// indices) cannot be read off this signature — confirm against the C
// prototype.
//
// The `*_can_implement` companions take no workspace or stream and declare
// `y` as `*const`. NOTE(review): presumably they only validate the
// arguments — confirm.
//
// Safety: the crate-level contract applies to every call (pointers are
// dereferenced unchecked; a non-null `workspace` must cover at least
// `workspace_bytes`; `stream` must be a valid CUDA stream).
/// Flip strided sibling, f32.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_flip_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
flip_axes: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `flip_f32_strided_can_implement` companion.
///
/// Same problem description as `baracuda_kernels_flip_f32_strided_run`,
/// without workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_flip_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
flip_axes: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Flip strided sibling, f16.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_flip_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
flip_axes: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `flip_f16_strided_can_implement` companion.
///
/// Same problem description as `baracuda_kernels_flip_f16_strided_run`,
/// without workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_flip_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
flip_axes: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Flip strided sibling, bf16.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_flip_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
flip_axes: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `flip_bf16_strided_can_implement` companion.
///
/// Same problem description as `baracuda_kernels_flip_bf16_strided_run`,
/// without workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_flip_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
flip_axes: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Flip strided sibling, f64.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_flip_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
flip_axes: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `flip_f64_strided_can_implement` companion.
///
/// Same problem description as `baracuda_kernels_flip_f64_strided_run`,
/// without workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_flip_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
flip_axes: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---- Roll strided siblings ----
//
// The four `*_run` / `*_can_implement` pairs in this block have identical
// parameter lists and differ only in the dtype suffix of the symbol
// (f32, f16, bf16, f64). Every function returns an `i32` status code as
// listed in the crate-level "Status codes" section (`0` is success).
//
// Parameter groups, in declaration order:
//   numel, rank, shape            problem geometry
//   shifts                        `i32` array of roll amounts
//   stride_x, stride_y            one stride array per tensor
//   x (`*const`), y               operand buffers
//   workspace, workspace_bytes,
//   stream                        `*_run` only; `stream` is a `cudaStream_t`
//                                 passed as `*mut c_void`
//
// NOTE(review): `shape` / `stride_*` presumably hold `rank` entries, and
// `shifts` presumably holds one shift per axis (`rank` entries) — confirm
// against the C prototype, including how negative shifts are treated.
//
// The `*_can_implement` companions take no workspace or stream and declare
// `y` as `*const`. NOTE(review): presumably they only validate the
// arguments — confirm.
//
// Safety: the crate-level contract applies to every call (pointers are
// dereferenced unchecked; a non-null `workspace` must cover at least
// `workspace_bytes`; `stream` must be a valid CUDA stream).
/// Roll strided sibling, f32.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_roll_f32_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
shifts: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `roll_f32_strided_can_implement` companion.
///
/// Same problem description as `baracuda_kernels_roll_f32_strided_run`,
/// without workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_roll_f32_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
shifts: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Roll strided sibling, f16.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_roll_f16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
shifts: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `roll_f16_strided_can_implement` companion.
///
/// Same problem description as `baracuda_kernels_roll_f16_strided_run`,
/// without workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_roll_f16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
shifts: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Roll strided sibling, bf16.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_roll_bf16_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
shifts: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `roll_bf16_strided_can_implement` companion.
///
/// Same problem description as `baracuda_kernels_roll_bf16_strided_run`,
/// without workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_roll_bf16_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
shifts: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Roll strided sibling, f64.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_roll_f64_strided_run(
numel: i64,
rank: i32,
shape: *const i32,
shifts: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `roll_f64_strided_can_implement` companion.
///
/// Same problem description as `baracuda_kernels_roll_f64_strided_run`,
/// without workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_roll_f64_strided_can_implement(
numel: i64,
rank: i32,
shape: *const i32,
shifts: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}
#[cfg(any(feature = "sm80", feature = "sm89", feature = "sm90a"))]
unsafe extern "C" {
// ---- Permute strided siblings ----
//
// The four `*_run` / `*_can_implement` pairs in this block have identical
// parameter lists and differ only in the dtype suffix of the symbol
// (f32, f16, bf16, f64). Every function returns an `i32` status code as
// listed in the crate-level "Status codes" section (`0` is success).
//
// Parameter groups, in declaration order:
//   input_numel, rank,
//   input_shape                   geometry of the input tensor
//   dims                          `i32` array describing the permutation
//   stride_x, stride_y            one stride array per tensor
//   x (`*const`), y               operand buffers
//   workspace, workspace_bytes,
//   stream                        `*_run` only; `stream` is a `cudaStream_t`
//                                 passed as `*mut c_void`
//
// NOTE(review): `input_shape`, `dims`, and `stride_*` presumably hold `rank`
// entries each. Whether `dims[i]` names the input axis placed at output
// axis `i` (or the inverse), and whether `stride_y` is indexed in input or
// output axis order, cannot be read off this signature — confirm against
// the C prototype.
//
// The `*_can_implement` companions take no workspace or stream and declare
// `y` as `*const`. NOTE(review): presumably they only validate the
// arguments — confirm.
//
// Safety: the crate-level contract applies to every call (pointers are
// dereferenced unchecked; a non-null `workspace` must cover at least
// `workspace_bytes`; `stream` must be a valid CUDA stream).
/// Permute strided sibling, f32.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_permute_f32_strided_run(
input_numel: i64,
rank: i32,
input_shape: *const i32,
dims: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `permute_f32_strided_can_implement` companion.
///
/// Same problem description as `baracuda_kernels_permute_f32_strided_run`,
/// without workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_permute_f32_strided_can_implement(
input_numel: i64,
rank: i32,
input_shape: *const i32,
dims: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Permute strided sibling, f16.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_permute_f16_strided_run(
input_numel: i64,
rank: i32,
input_shape: *const i32,
dims: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `permute_f16_strided_can_implement` companion.
///
/// Same problem description as `baracuda_kernels_permute_f16_strided_run`,
/// without workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_permute_f16_strided_can_implement(
input_numel: i64,
rank: i32,
input_shape: *const i32,
dims: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Permute strided sibling, bf16.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_permute_bf16_strided_run(
input_numel: i64,
rank: i32,
input_shape: *const i32,
dims: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `permute_bf16_strided_can_implement` companion.
///
/// Same problem description as `baracuda_kernels_permute_bf16_strided_run`,
/// without workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_permute_bf16_strided_can_implement(
input_numel: i64,
rank: i32,
input_shape: *const i32,
dims: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
/// Permute strided sibling, f64.
///
/// Returns an `i32` status code (`0` is success).
pub fn baracuda_kernels_permute_f64_strided_run(
input_numel: i64,
rank: i32,
input_shape: *const i32,
dims: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *mut c_void,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32;
/// `permute_f64_strided_can_implement` companion.
///
/// Same problem description as `baracuda_kernels_permute_f64_strided_run`,
/// without workspace or stream; returns an `i32` status code.
pub fn baracuda_kernels_permute_f64_strided_can_implement(
input_numel: i64,
rank: i32,
input_shape: *const i32,
dims: *const i32,
stride_x: *const i64,
stride_y: *const i64,
x: *const c_void,
y: *const c_void,
) -> i32;
}