use crate::error::{CoreError, Result};
use candle_core::{DType, Device, Tensor};
/// Host-side, dtype-tagged byte buffer used as the interchange format
/// between candle tensors and CubeCL kernels.
///
/// The converters in this module fill `bytes` with little-endian,
/// row-major (contiguous) element data.
#[derive(Debug, Clone)]
pub struct TensorBuffer {
    // Raw element bytes, little-endian, in row-major order.
    pub bytes: Vec<u8>,
    // Logical tensor dimensions.
    pub shape: Vec<usize>,
    // Element dtype describing how `bytes` is to be interpreted
    // (the converters below support F32/F16/BF16).
    pub dtype: DType,
}
impl TensorBuffer {
pub fn new(bytes: Vec<u8>, shape: Vec<usize>, dtype: DType) -> Self {
Self {
bytes,
shape,
dtype,
}
}
pub fn numel(&self) -> usize {
self.shape.iter().product()
}
pub fn size_bytes(&self) -> usize {
self.bytes.len()
}
}
/// Reports whether a CUDA device can be acquired through candle, i.e.
/// whether the CubeCL CUDA bridge in this module is usable at all.
#[must_use]
pub fn has_cubecl_cuda_support() -> bool {
    Device::cuda_if_available(0)
        .map(|device| matches!(device, Device::Cuda(_)))
        .unwrap_or(false)
}
/// Copies a candle CUDA tensor to a host-side [`TensorBuffer`] containing
/// the element data as little-endian bytes in row-major order.
///
/// The tensor is made contiguous first, so strided views are densely packed
/// before serialization.
///
/// # Errors
/// Returns an error if the tensor does not live on a CUDA device, if candle
/// fails to copy the data to the host, or if the dtype is not F32/F16/BF16.
pub fn candle_to_cubecl_handle(tensor: &Tensor) -> Result<TensorBuffer> {
    if !matches!(tensor.device(), Device::Cuda(_)) {
        return Err(CoreError::invalid_config(
            "candle_to_cubecl_handle requires CUDA tensor",
        ));
    }
    // Ensure row-major, densely packed data before flattening to a host Vec.
    let tensor = tensor.contiguous()?;
    let shape = tensor.dims().to_vec();
    let dtype = tensor.dtype();

    // Serialize into a preallocated Vec. Collecting from `flat_map` (as the
    // previous version did) carries a zero size hint, so the byte vector
    // would reallocate repeatedly for large tensors.
    fn serialize_le<T: Copy, const N: usize>(
        data: &[T],
        to_le: impl Fn(T) -> [u8; N],
    ) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(data.len() * N);
        for &value in data {
            bytes.extend_from_slice(&to_le(value));
        }
        bytes
    }

    let bytes = match dtype {
        DType::F32 => {
            let data: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
            serialize_le(&data, f32::to_le_bytes)
        }
        DType::F16 => {
            let data: Vec<half::f16> = tensor.flatten_all()?.to_vec1()?;
            serialize_le(&data, half::f16::to_le_bytes)
        }
        DType::BF16 => {
            let data: Vec<half::bf16> = tensor.flatten_all()?.to_vec1()?;
            serialize_le(&data, half::bf16::to_le_bytes)
        }
        _ => {
            return Err(CoreError::invalid_config(format!(
                "candle_to_cubecl_handle does not support dtype {dtype:?}"
            )));
        }
    };
    Ok(TensorBuffer::new(bytes, shape, dtype))
}
/// Reconstructs a candle tensor on `device` from a host-side [`TensorBuffer`],
/// interpreting `buffer.bytes` as little-endian elements of `buffer.dtype`.
///
/// # Errors
/// Returns an error if `device` is not CUDA, if the byte length does not
/// match `shape` × dtype width, if the dtype is not F32/F16/BF16, or if
/// candle fails to upload the data.
pub fn cubecl_to_candle_tensor(buffer: &TensorBuffer, device: &Device) -> Result<Tensor> {
    if !matches!(device, Device::Cuda(_)) {
        return Err(CoreError::invalid_config(
            "cubecl_to_candle_tensor requires CUDA device",
        ));
    }
    let numel = buffer.numel();
    // checked_mul: a corrupt/adversarial shape could overflow `usize`; an
    // unchecked multiply panics in debug and wraps in release, where the
    // wrapped value might spuriously pass the length check below.
    let expected_bytes = numel
        .checked_mul(buffer.dtype.size_in_bytes())
        .ok_or_else(|| {
            CoreError::invalid_config("cubecl_to_candle_tensor: buffer size overflows usize")
        })?;
    if buffer.bytes.len() != expected_bytes {
        return Err(CoreError::shape_mismatch(
            vec![expected_bytes],
            vec![buffer.bytes.len()],
        ));
    }
    let tensor = match buffer.dtype {
        DType::F32 => {
            let data: Vec<f32> = buffer
                .bytes
                .chunks_exact(4)
                // try_into cannot fail: chunks_exact yields exactly 4 bytes.
                .map(|chunk| f32::from_le_bytes(chunk.try_into().expect("4-byte chunk")))
                .collect();
            Tensor::from_vec(data, buffer.shape.as_slice(), device)?
        }
        DType::F16 => {
            let data: Vec<half::f16> = buffer
                .bytes
                .chunks_exact(2)
                .map(|chunk| half::f16::from_le_bytes(chunk.try_into().expect("2-byte chunk")))
                .collect();
            Tensor::from_vec(data, buffer.shape.as_slice(), device)?
        }
        DType::BF16 => {
            let data: Vec<half::bf16> = buffer
                .bytes
                .chunks_exact(2)
                .map(|chunk| half::bf16::from_le_bytes(chunk.try_into().expect("2-byte chunk")))
                .collect();
            Tensor::from_vec(data, buffer.shape.as_slice(), device)?
        }
        _ => {
            return Err(CoreError::invalid_config(format!(
                "cubecl_to_candle_tensor does not support dtype {:?}",
                buffer.dtype
            )));
        }
    };
    Ok(tensor)
}
pub fn allocate_output_buffer(shape: &[usize], dtype: DType) -> Result<TensorBuffer> {
let numel: usize = shape.iter().product();
let size_bytes = numel * dtype.size_in_bytes();
let bytes = vec![0u8; size_bytes];
Ok(TensorBuffer::new(bytes, shape.to_vec(), dtype))
}
#[cfg(test)]
mod tests {
    use super::*;

    // A buffer's element count comes from its shape; its byte size comes
    // from the backing storage, independent of shape/dtype.
    #[test]
    fn test_tensor_buffer() {
        let storage = vec![0u8; 64];
        let buffer = TensorBuffer::new(storage, vec![2, 8], DType::F32);
        assert_eq!(buffer.numel(), 2 * 8);
        assert_eq!(buffer.size_bytes(), 64);
    }

    // Allocation sizes the zeroed storage as numel × dtype width.
    #[test]
    fn test_allocate_output_buffer() {
        let buffer = allocate_output_buffer(&[4, 8, 16], DType::F32)
            .expect("small allocation must succeed");
        assert_eq!(buffer.numel(), 4 * 8 * 16);
        assert_eq!(buffer.size_bytes(), 4 * 8 * 16 * 4);
    }
}