/// Fused selective-scan (SSM) forward pass dispatched to a custom CUDA kernel.
///
/// Makes every operand contiguous, flattens the per-head dimensions of `x`,
/// `b`, and `c`, then hands raw `f32` device pointers to
/// `crate::cuda::ffi::selective_scan_cuda`. The kernel writes the output
/// sequence into a freshly allocated device buffer (returned as `y`) and
/// updates the recurrent `state` tensor in place through a mutable pointer.
///
/// Shape assumptions (derived from the reshapes below):
/// * `x`: `(batch, seq_len, n_heads, head_dim)`
/// * `b`, `c`: 4-D with trailing dim `d_state`; flattening to
///   `n_heads * d_state` assumes their third dim equals `n_heads` —
///   TODO(review): confirm grouped layouts (`n_groups != n_heads`) are
///   rejected or expanded upstream, since the reshape would fail here.
///
/// All tensors must be `f32` and on the same CUDA device.
///
/// # Errors
/// Returns an error if any tensor is not on CUDA, is not `f32`, or does not
/// have the expected rank for the `dims4`/`reshape` calls below.
#[cfg(feature = "cuda")]
// Arity is fixed by the C kernel signature; keeps the lint state consistent
// with the non-CUDA stub of this function.
#[allow(clippy::too_many_arguments)]
pub fn selective_scan_cuda(
x: &candle_core::Tensor,
dt: &candle_core::Tensor,
a: &candle_core::Tensor,
b: &candle_core::Tensor,
c: &candle_core::Tensor,
d: &candle_core::Tensor,
dt_bias: &candle_core::Tensor,
state: &mut candle_core::Tensor,
dt_min: f32,
dt_max: f32,
) -> candle_core::Result<candle_core::Tensor> {
use candle_core as candle;
use candle_core::cuda_backend::cudarc::driver::DevicePtr;
// The kernel indexes with fixed strides, so force contiguous storage.
// These bindings also own the (possibly copied) buffers, keeping every
// device allocation alive across the kernel launch below.
let x = x.contiguous()?;
let dt = dt.contiguous()?;
let a = a.contiguous()?;
let b = b.contiguous()?;
let c = c.contiguous()?;
let d = d.contiguous()?;
let dt_bias = dt_bias.contiguous()?;
let (batch_size, seq_len, n_heads, head_dim) = x.dims4()?;
// Only the trailing state dimension of `b` is needed; dims.2 is ignored.
let (_, d_state) = {
let dims = b.dims4()?;
(dims.2, dims.3)
};
// Collapse the head dimensions: the kernel consumes 3-D (batch, seq, feat)
// views of x, b, and c.
let x_flat = x.reshape((batch_size, seq_len, n_heads * head_dim))?;
let b_flat = b.reshape((batch_size, seq_len, n_heads * d_state))?;
let c_flat = c.reshape((batch_size, seq_len, n_heads * d_state))?;
let dev = x_flat.device().as_cuda_device()?;
// Extract a raw device pointer (offset by the layout's start) from an f32
// CUDA tensor. NOTE: the storage guard `s` is dropped when the macro block
// ends while the raw pointer escapes; this is sound only because the tensor
// itself stays alive in this function until after the kernel launch.
macro_rules! get_ptr {
($tensor:expr) => {{
let (s, l) = $tensor.storage_and_layout();
let s = match &*s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("selective_scan_cuda: tensor must be on CUDA"),
};
let ptr = s.slice(l.start_offset()..).device_ptr(s.stream()).0 as *const f32;
ptr
}};
}
let x_ptr = get_ptr!(x_flat);
let dt_ptr = get_ptr!(dt);
let a_ptr = get_ptr!(a);
let b_ptr = get_ptr!(b_flat);
let c_ptr = get_ptr!(c_flat);
let d_ptr = get_ptr!(d);
let dt_bias_ptr = get_ptr!(dt_bias);
// `state` is updated in place, so its pointer is extracted by hand (the
// macro only produces *const pointers). The shadowed outer guard from
// `storage_and_layout` lives until the end of the function, spanning the
// kernel launch.
let (state_s, state_l) = state.storage_and_layout();
let state_s = match &*state_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("selective_scan_cuda: state must be on CUDA"),
};
let state_ptr = {
let ptr = state_s
.slice(state_l.start_offset()..)
.device_ptr(state_s.stream())
.0 as *mut f32;
ptr
};
let y_elems = batch_size * seq_len * n_heads * head_dim;
// SAFETY: uninitialized device memory; every element is written by the
// kernel before `y_buf` is exposed as a tensor.
let y_buf = unsafe { dev.alloc::<f32>(y_elems) }?;
// The CUDA stream handle is smuggled across the FFI boundary as an i64.
let stream = dev.cuda_stream().cu_stream() as i64;
#[allow(clippy::cast_possible_truncation)]
// SAFETY: all pointers refer to contiguous f32 device buffers whose owners
// (the tensors above, `state`'s storage guard, and `y_buf`) outlive the
// launch, and the extent arguments match the buffers' logical shapes.
unsafe {
crate::cuda::ffi::selective_scan_cuda(
x_ptr,
dt_ptr,
a_ptr,
b_ptr,
c_ptr,
d_ptr,
dt_bias_ptr,
state_ptr,
{
let p = y_buf.device_ptr(y_buf.stream()).0 as *mut f32;
p
},
batch_size as i32,
n_heads as i32,
head_dim as i32,
d_state as i32,
seq_len as i32,
dt_min,
dt_max,
stream,
);
}
// Hand ownership of the output buffer to candle and restore the 4-D shape.
let y_storage = candle::CudaStorage::wrap_cuda_slice(y_buf, dev.clone());
let y = candle_core::Tensor::from((
candle::Storage::Cuda(y_storage),
(batch_size, seq_len, n_heads, head_dim),
));
Ok(y)
}
/// CPU-only stand-in for the CUDA selective-scan entry point, compiled when
/// the crate is built without the `cuda` feature; it never computes anything.
///
/// Keeping the same signature as the CUDA variant lets callers compile
/// unconditionally and fail at runtime instead of at link time.
///
/// # Errors
/// Always returns an error stating that the `cuda` feature is required.
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn selective_scan_cuda(
_x: &candle_core::Tensor,
_dt: &candle_core::Tensor,
_a: &candle_core::Tensor,
_b: &candle_core::Tensor,
_c: &candle_core::Tensor,
_d: &candle_core::Tensor,
_dt_bias: &candle_core::Tensor,
_state: &mut candle_core::Tensor,
_dt_min: f32,
_dt_max: f32,
) -> candle_core::Result<candle_core::Tensor> {
candle_core::bail!("selective_scan_cuda requires the cuda feature")
}