#![allow(dead_code)]
use candle_core::backend::BackendDevice;
use candle_core::cuda_backend::CudaDType;
use candle_core::{CudaDevice, CudaStorage, DeviceLocation, Error, Layout, Result};
use cudarc::driver::CudaSlice;
use kaio::prelude::{GpuBuffer, KaioDevice, KaioError};
pub(crate) fn slice_ref_from_storage<T: CudaDType>(storage: &CudaStorage) -> Result<&CudaSlice<T>> {
storage.as_cuda_slice::<T>()
}
/// Packages an owned `CudaSlice<T>` together with its `CudaDevice` into a
/// candle `CudaStorage`.
///
/// Both arguments are consumed and handed straight to
/// `T::wrap_cuda_slice`, whose return value is passed through as-is.
pub(crate) fn storage_from_slice<T>(slice: CudaSlice<T>, device: CudaDevice) -> CudaStorage
where
    T: CudaDType,
{
    T::wrap_cuda_slice(slice, device)
}
/// Reinterprets a borrowed `CudaSlice<T>` as a `&GpuBuffer<T>` without
/// copying or allocating.
///
/// The returned reference borrows from `slice`, so it cannot outlive it.
pub(crate) fn buffer_ref_from_slice_readonly<T>(slice: &CudaSlice<T>) -> &GpuBuffer<T> {
    let raw = (slice as *const CudaSlice<T>).cast::<GpuBuffer<T>>();
    // SAFETY: `raw` is derived from a live shared reference, so it is non-null
    // and valid for reads of a `CudaSlice<T>` for the returned lifetime.
    // Soundness additionally requires `GpuBuffer<T>` to have the same layout
    // as `CudaSlice<T>`.
    // NOTE(review): that layout guarantee lives in kaio and is not visible
    // from this file (e.g. a `#[repr(transparent)]` wrapper) — confirm it.
    unsafe { &*raw }
}
/// Reinterprets a borrowed `CudaSlice<u8>` as a `&CudaSlice<i8>` without
/// copying.
///
/// The returned reference borrows from `slice`, so it cannot outlive it.
pub(crate) fn reinterpret_u8_slice_as_i8(slice: &CudaSlice<u8>) -> &CudaSlice<i8> {
    let raw = (slice as *const CudaSlice<u8>).cast::<CudaSlice<i8>>();
    // SAFETY: `raw` is derived from a live shared reference, so it is non-null
    // and valid for reads for the returned lifetime. `u8` and `i8` have the
    // same size and alignment.
    // NOTE(review): this also assumes `CudaSlice<u8>` and `CudaSlice<i8>` are
    // laid out identically, which depends on cudarc's definition of
    // `CudaSlice` — confirm against the cudarc version in use.
    unsafe { &*raw }
}
fn candle_ordinal(dev: &CudaDevice) -> Result<usize> {
match dev.location() {
DeviceLocation::Cuda { gpu_id } => Ok(gpu_id),
other => Err(Error::Msg(format!(
"kaio-candle: expected CUDA device, candle reports location {other:?}"
))),
}
}
pub(crate) fn ensure_ordinal_match(candle_dev: &CudaDevice, kaio_dev: &KaioDevice) -> Result<()> {
let candle_ord = candle_ordinal(candle_dev)?;
let kaio_ord = kaio_dev.ordinal();
if candle_ord != kaio_ord {
return Err(Error::Msg(format!(
"kaio-candle: input tensor is on CUDA ordinal {candle_ord}, \
but the Arc<KaioDevice> passed is ordinal {kaio_ord}. \
Construct a KaioDevice on the same ordinal as the candle Device."
)));
}
Ok(())
}
/// Joins the kaio stream onto the candle stream ahead of a kaio launch.
///
/// Calls `join` on the kaio device's stream with candle's stream as the
/// argument.
///
/// NOTE(review): presumably this makes the kaio stream wait for work already
/// queued on candle's stream, per cudarc's `CudaStream::join` — confirm.
///
/// # Errors
///
/// A driver error from `join` is converted with `driver_err`.
pub(crate) fn sync_before_launch(candle_dev: &CudaDevice, kaio_dev: &KaioDevice) -> Result<()> {
    let candle_side = candle_dev.cuda_stream();
    let kaio_side = kaio_dev.stream();
    kaio_side.join(&candle_side).map_err(driver_err)?;
    Ok(())
}
/// Joins the candle stream onto the kaio stream after a kaio launch.
///
/// Calls `join` on candle's stream with the kaio device's stream as the
/// argument — the mirror image of `sync_before_launch`.
///
/// NOTE(review): presumably this makes candle's stream wait for work already
/// queued on the kaio stream, per cudarc's `CudaStream::join` — confirm.
///
/// # Errors
///
/// A driver error from `join` is converted with `driver_err`.
pub(crate) fn sync_after_launch(candle_dev: &CudaDevice, kaio_dev: &KaioDevice) -> Result<()> {
    let kaio_side = kaio_dev.stream();
    let candle_side = candle_dev.cuda_stream();
    candle_side.join(kaio_side).map_err(driver_err)?;
    Ok(())
}
pub(crate) fn kaio_err(e: KaioError) -> Error {
Error::Msg(format!("kaio: {e}"))
}
pub(crate) fn driver_err(e: cudarc::driver::DriverError) -> Error {
Error::Msg(format!("kaio-candle stream sync: {e}"))
}
/// Validates that `layout` describes a rank-2, contiguous tensor starting at
/// storage offset 0, and returns its two dimension sizes in order.
///
/// `op_name` and `input_label` are only used to build error messages.
///
/// # Errors
///
/// Returns `Error::Msg` for the first failing check, tested in this order:
/// 1. the shape does not have exactly two dimensions;
/// 2. `layout.is_contiguous()` is false;
/// 3. `layout.start_offset()` is non-zero.
fn ensure_rank2_contiguous_zero_offset_inner(
    op_name: &str,
    input_label: &str,
    layout: &Layout,
) -> Result<(usize, usize)> {
    let shape = layout.shape();

    // The slice pattern binds both extents only for exactly two dimensions;
    // every other rank falls through to the error arm.
    let (rows, cols) = match shape.dims() {
        &[rows, cols] => (rows, cols),
        dims => {
            return Err(Error::Msg(format!(
                "kaio-candle::{op_name}: {input_label} must be rank-2; \
                 got rank-{rank} input of shape {shape:?}. \
                 For multi-head attention, reshape to rank-2 via \
                 `.reshape((seq, d))?` after flattening batch+heads.",
                rank = dims.len()
            )));
        }
    };

    if !layout.is_contiguous() {
        let strides = layout.stride();
        return Err(Error::Msg(format!(
            "kaio-candle::{op_name}: {input_label} must be contiguous; \
             got shape {shape:?} with strides {strides:?}. \
             Call `.contiguous()?` first."
        )));
    }

    // Checked separately from contiguity: the offset is rejected on its own
    // even when the contiguity check above has passed.
    let off = layout.start_offset();
    if off != 0 {
        return Err(Error::Msg(format!(
            "kaio-candle::{op_name}: {input_label} must start at storage offset 0; \
             got offset {off} (likely from a `.slice(..)` / `.narrow(..)` call). \
             Call `.contiguous()?` to compact."
        )));
    }

    Ok((rows, cols))
}
/// Validates an input identified by its position, labelling it
/// `input #<input_index>` in any error message.
///
/// Delegates to `ensure_rank2_contiguous_zero_offset_inner` and returns its
/// result unchanged: the two dimension sizes on success, `Error::Msg`
/// otherwise.
pub(crate) fn ensure_rank2_contiguous_zero_offset(
    op_name: &'static str,
    input_index: usize,
    layout: &Layout,
) -> Result<(usize, usize)> {
    let input_label = format!("input #{input_index}");
    ensure_rank2_contiguous_zero_offset_inner(op_name, &input_label, layout)
}
/// Validates an input identified by name, using `param_name` verbatim as the
/// label in any error message.
///
/// Delegates to `ensure_rank2_contiguous_zero_offset_inner` and returns its
/// result unchanged: the two dimension sizes on success, `Error::Msg`
/// otherwise.
pub(crate) fn ensure_rank2_contiguous_zero_offset_named(
    op_name: &str,
    param_name: &str,
    layout: &Layout,
) -> Result<(usize, usize)> {
    // The parameter name doubles as the human-readable input label.
    let input_label = param_name;
    ensure_rank2_contiguous_zero_offset_inner(op_name, input_label, layout)
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{Layout, Shape};

    /// Builds the layout `Layout::contiguous` produces for `dims`.
    fn contiguous_layout(dims: &[usize]) -> Layout {
        Layout::contiguous(Shape::from_dims(dims))
    }

    /// A rank-2 layout from `Layout::contiguous` is accepted and its
    /// dimensions are returned in order.
    #[test]
    fn rank2_contiguous_zero_offset_returns_dims() {
        let layout = contiguous_layout(&[64, 128]);
        let (rows, cols) =
            ensure_rank2_contiguous_zero_offset("test", 0, &layout).expect("happy path");
        assert_eq!((rows, cols), (64, 128));
    }

    /// A rank-1 input is rejected, and the message carries the op name, the
    /// input label and the reshape hint.
    #[test]
    fn rank1_rejected_with_reshape_hint() {
        let layout = contiguous_layout(&[128]);
        let msg = ensure_rank2_contiguous_zero_offset("test_op", 0, &layout)
            .expect_err("rank-1 must fail")
            .to_string();

        let expectations = [
            ("must be rank-2", "'must be rank-2'"),
            ("reshape", "reshape hint"),
            ("test_op", "op name"),
            ("input #0", "input index"),
        ];
        for (needle, what) in expectations {
            assert!(msg.contains(needle), "expected {what} in {msg}");
        }
    }

    /// A rank-3 input is rejected and the message names the offending input.
    #[test]
    fn rank3_rejected_with_reshape_hint() {
        let layout = contiguous_layout(&[8, 64, 64]);
        let msg = ensure_rank2_contiguous_zero_offset("test_op", 1, &layout)
            .expect_err("rank-3 must fail")
            .to_string();
        assert!(msg.contains("must be rank-2"));
        assert!(msg.contains("input #1"));
    }

    /// A rank-4 input is rejected.
    #[test]
    fn rank4_rejected() {
        let layout = contiguous_layout(&[2, 8, 64, 64]);
        let outcome = ensure_rank2_contiguous_zero_offset("test", 0, &layout);
        assert!(outcome.is_err());
    }
}