use crate::error::{Result, UnslothError};
use candle_core::{DType, Device, Tensor};
/// Bits in one `u32` word of a packed bit-plane.
const BITS_PER_U32: usize = 32;
/// Bits in one `u64` word of a sparsity bitmap.
const BITS_PER_U64: usize = 64;
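/// Reports whether a CUDA device is reachable through Candle.
///
/// Always `false` when the crate is built without the `cuda` feature.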
#[must_use]
pub fn has_cubecl_cuda_support() -> bool {
#[cfg(feature = "cuda")]
{
matches!(Device::cuda_if_available(0), Ok(Device::Cuda(_)))
}
#[cfg(not(feature = "cuda"))]
{
false
}
}
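/// Serializes an f32 CUDA tensor into little-endian bytes plus its shape and
/// dtype, ready for upload into a CubeCL buffer. The tensor is made
/// contiguous first, so strided views are handled.
///
/// # Errors
/// Returns `UnslothError::InvalidConfig` if the tensor is not on a CUDA
/// device or its dtype is not `DType::F32`.
///
/// A minimal roundtrip sketch (ignored doctest: it assumes a CUDA device at
/// index 0 and elides error handling):
///
/// ```ignore
/// let device = Device::cuda_if_available(0)?;
/// let t = Tensor::randn(0.0f32, 1.0, (2, 4), &device)?;
/// let (bytes, shape, _dtype) = candle_to_cubecl_handle(&t)?;
/// let back = cubecl_to_candle_tensor(&bytes, &shape, &device)?;
/// assert_eq!(t.dims(), back.dims());
/// ```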
pub fn candle_to_cubecl_handle(tensor: &Tensor) -> Result<(Vec<u8>, Vec<usize>, DType)> {
if !matches!(tensor.device(), Device::Cuda(_)) {
return Err(UnslothError::InvalidConfig(
"candle_to_cubecl_handle requires CUDA tensor".to_string(),
));
}
let tensor = tensor.contiguous()?;
let shape = tensor.dims().to_vec();
let dtype = tensor.dtype();
if dtype != DType::F32 {
return Err(UnslothError::InvalidConfig(format!(
"candle_to_cubecl_handle only supports f32, got {dtype:?}"
)));
}
let data: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
Ok((bytes, shape, dtype))
}
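/// Rebuilds an f32 Candle tensor on the given CUDA device from little-endian
/// bytes, the inverse of [`candle_to_cubecl_handle`].
///
/// # Errors
/// Returns `UnslothError::InvalidConfig` if `device` is not CUDA or if
/// `bytes.len()` does not equal `4 * shape.iter().product()`.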
pub fn cubecl_to_candle_tensor(bytes: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
if !matches!(device, Device::Cuda(_)) {
return Err(UnslothError::InvalidConfig(
"cubecl_to_candle_tensor requires CUDA device".to_string(),
));
}
let num_elements: usize = shape.iter().product();
let expected_bytes = num_elements * 4;
if bytes.len() != expected_bytes {
return Err(UnslothError::InvalidConfig(format!(
"Byte count mismatch: expected {} for shape {:?}, got {}",
expected_bytes,
shape,
bytes.len()
)));
}
let data: Vec<f32> = bytes
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect();
let tensor = Tensor::from_vec(data, shape, device)?;
Ok(tensor)
}
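/// Allocates a zeroed host-side staging buffer sized for `num_elements` f32
/// values (4 bytes each).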
#[must_use]
pub fn allocate_output_buffer(num_elements: usize) -> Vec<u8> {
vec![0u8; num_elements * 4]
}
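/// Flattens a packed `u32` bit-plane into little-endian bytes for CubeCL.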
#[must_use]
pub fn u32_planes_to_cubecl_bytes(plane: &[u32]) -> Vec<u8> {
plane.iter().flat_map(|&word| word.to_le_bytes()).collect()
}
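/// Unpacks a `TernaryTensor` into the raw buffers a CubeCL kernel consumes:
/// the `+1` plane bytes, the `-1` plane bytes, the per-row f32 scale bytes,
/// the logical `(out_features, in_features)` shape, and the number of `u32`
/// words per row.
///
/// A sketch of consuming the returned tuple (ignored doctest; `tensor` is
/// assumed to be a `TernaryTensor` built elsewhere):
///
/// ```ignore
/// let (plus_bytes, minus_bytes, scales_bytes, (rows, _cols), k_words) =
///     ternary_tensor_to_cubecl_handles(&tensor);
/// assert_eq!(plus_bytes.len(), rows * k_words * 4); // 4 bytes per u32 word
/// ```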
#[must_use]
pub fn ternary_tensor_to_cubecl_handles(
tensor: &crate::kernels::ternary::TernaryTensor,
) -> (Vec<u8>, Vec<u8>, Vec<u8>, (usize, usize), usize) {
let plus_bytes = u32_planes_to_cubecl_bytes(&tensor.plus_plane);
let minus_bytes = u32_planes_to_cubecl_bytes(&tensor.minus_plane);
let scales_bytes: Vec<u8> = tensor
.scales
.iter()
.flat_map(|&s| s.to_le_bytes())
.collect();
(
plus_bytes,
minus_bytes,
scales_bytes,
tensor.shape,
tensor.k_words,
)
}
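/// Inverse of [`u32_planes_to_cubecl_bytes`]: reassembles little-endian bytes
/// into `u32` words. Trailing bytes that do not fill a whole word are dropped.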
#[must_use]
pub fn cubecl_bytes_to_u32_plane(bytes: &[u8]) -> Vec<u32> {
bytes
.chunks_exact(4)
.map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect()
}
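/// Serializes the `active_chunks` words of a `SparsityMetadata` into
/// little-endian bytes for upload to a CubeCL buffer.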
#[must_use]
pub fn sparsity_metadata_to_cubecl_bytes(
metadata: &crate::kernels::ternary::SparsityMetadata,
) -> Vec<u8> {
metadata
.active_chunks
.iter()
.flat_map(|&word| word.to_le_bytes())
.collect()
}
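/// Builds a per-row sparsity bitmap for a `TernaryTensor`, serialized as
/// little-endian `u64` words.
///
/// Each row's `k_words` plane words are grouped into chunks of `chunk_size`
/// bits (`chunk_size / 32` words per chunk; `chunk_size` is assumed to be a
/// positive multiple of 32). Bit `c` of a row's bitmap is set when chunk `c`
/// contains a nonzero word in either plane, which lets a kernel skip
/// all-zero chunks. Each row occupies `ceil(num_chunks / 64)` `u64` words.
///
/// For example, with `k_words = 4` and `chunk_size = 64` there are two
/// two-word chunks per row: an all-zero row yields a bitmap word of `0`,
/// while a fully dense row yields `0b11`.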
#[must_use]
pub fn create_sparsity_bitmap_for_tensor(
tensor: &crate::kernels::ternary::TernaryTensor,
chunk_size: usize,
) -> Vec<u8> {
let (out_features, _in_features) = tensor.shape;
let k_words = tensor.k_words;
    // `chunk_size` is given in bits; convert it to u32 words per chunk, then
    // round the chunk count and per-row bitmap word count up.
    let words_per_chunk = chunk_size / BITS_PER_U32;
    let num_chunks = (k_words + words_per_chunk - 1) / words_per_chunk;
    let bitmap_words = (num_chunks + BITS_PER_U64 - 1) / BITS_PER_U64;
    let mut bitmap = vec![0u64; out_features * bitmap_words];
for row in 0..out_features {
        for chunk_idx in 0..num_chunks {
            // Word range covered by this chunk, clamped to the row boundary.
            let word_start = row * k_words + chunk_idx * words_per_chunk;
            let word_end = std::cmp::min(word_start + words_per_chunk, (row + 1) * k_words);
            // A chunk is active when either plane has a nonzero word in it.
            let is_active = tensor.plus_plane[word_start..word_end]
                .iter()
                .zip(&tensor.minus_plane[word_start..word_end])
                .any(|(&p, &m)| p != 0 || m != 0);
            if is_active {
                // Set bit `chunk_idx` in this row's bitmap.
                let bitmap_idx = row * bitmap_words + chunk_idx / BITS_PER_U64;
let bit_idx = chunk_idx % BITS_PER_U64;
bitmap[bitmap_idx] |= 1u64 << bit_idx;
}
}
}
bitmap.iter().flat_map(|&word| word.to_le_bytes()).collect()
}
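/// Serializes a `DType::U32` CUDA tensor into little-endian bytes plus its
/// shape and dtype, e.g. for packed ternary planes already held in Candle.
///
/// # Errors
/// Returns `UnslothError::InvalidConfig` if the tensor is not on a CUDA
/// device or its dtype is not `DType::U32`.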
pub fn u32_tensor_to_cubecl_handle(tensor: &Tensor) -> Result<(Vec<u8>, Vec<usize>, DType)> {
if !matches!(tensor.device(), Device::Cuda(_)) {
return Err(UnslothError::InvalidConfig(
"u32_tensor_to_cubecl_handle requires CUDA tensor".to_string(),
));
}
let tensor = tensor.contiguous()?;
let shape = tensor.dims().to_vec();
let dtype = tensor.dtype();
if dtype != DType::U32 {
        return Err(UnslothError::InvalidConfig(format!(
            "u32_tensor_to_cubecl_handle only supports U32, got {dtype:?}"
}
let data: Vec<u32> = tensor.flatten_all()?.to_vec1()?;
let bytes: Vec<u8> = data.iter().flat_map(|u| u.to_le_bytes()).collect();
Ok((bytes, shape, dtype))
}
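/// Inverse of [`u32_tensor_to_cubecl_handle`]: rebuilds a `DType::U32` tensor
/// on the given CUDA device from little-endian bytes.
///
/// # Errors
/// Returns `UnslothError::InvalidConfig` if `device` is not CUDA or if
/// `bytes.len()` does not equal `4 * shape.iter().product()`.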
pub fn cubecl_to_u32_candle_tensor(
bytes: &[u8],
shape: &[usize],
device: &Device,
) -> Result<Tensor> {
if !matches!(device, Device::Cuda(_)) {
return Err(UnslothError::InvalidConfig(
"cubecl_to_u32_candle_tensor requires CUDA device".to_string(),
));
}
let num_elements: usize = shape.iter().product();
let expected_bytes = num_elements * 4;
if bytes.len() != expected_bytes {
return Err(UnslothError::InvalidConfig(format!(
"Byte count mismatch: expected {} for shape {:?}, got {}",
expected_bytes,
shape,
bytes.len()
)));
}
let data: Vec<u32> = bytes
.chunks_exact(4)
.map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect();
let tensor = Tensor::from_vec(data, shape, device)?;
Ok(tensor)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_has_cubecl_cuda_support() {
let _ = has_cubecl_cuda_support();
}
#[test]
fn test_allocate_output_buffer() {
let buffer = allocate_output_buffer(100);
        assert_eq!(buffer.len(), 400);
    }
#[test]
fn test_candle_to_cubecl_cpu_error() {
let tensor = Tensor::zeros((2, 4), DType::F32, &Device::Cpu).unwrap();
let result = candle_to_cubecl_handle(&tensor);
assert!(result.is_err());
}
#[test]
fn test_u32_planes_to_bytes_roundtrip() {
let original: Vec<u32> = vec![0xDEADBEEF, 0xCAFEBABE, 0x12345678];
let bytes = u32_planes_to_cubecl_bytes(&original);
let recovered = cubecl_bytes_to_u32_plane(&bytes);
assert_eq!(original, recovered);
}
#[test]
fn test_ternary_tensor_to_cubecl_handles() {
use crate::kernels::ternary::TernaryTensor;
        let shape = (4, 64);
        let k_words = 2;
let plus = vec![0xAAAAAAAAu32; 4 * k_words];
let minus = vec![0x55555555u32; 4 * k_words];
let scales = vec![1.5f32; 4];
let expected_plus = plus.clone();
let tensor = TernaryTensor::new(plus, minus, scales, shape);
let (plus_bytes, minus_bytes, scales_bytes, ret_shape, ret_k_words) =
ternary_tensor_to_cubecl_handles(&tensor);
assert_eq!(ret_shape, shape);
assert_eq!(ret_k_words, k_words);
assert_eq!(plus_bytes.len(), 4 * k_words * 4);
assert_eq!(minus_bytes.len(), 4 * k_words * 4);
assert_eq!(scales_bytes.len(), 4 * 4);
let recovered_plus = cubecl_bytes_to_u32_plane(&plus_bytes);
assert_eq!(expected_plus, recovered_plus);
}
#[test]
fn test_u32_tensor_cpu_error() {
let tensor = Tensor::zeros((4, 8), DType::U32, &Device::Cpu).unwrap();
let result = u32_tensor_to_cubecl_handle(&tensor);
assert!(result.is_err());
if let Err(e) = result {
assert!(e.to_string().contains("CUDA"));
}
}
#[test]
fn test_u32_bytes_roundtrip() {
let original: Vec<u32> = vec![
0x12345678, 0xABCDEF01, 0xDEADBEEF, 0xCAFEBABE, 0xFFFFFFFF, 0x00000000,
];
let bytes = u32_planes_to_cubecl_bytes(&original);
let recovered = cubecl_bytes_to_u32_plane(&bytes);
assert_eq!(original, recovered);
        // Little-endian check on the first word, 0x12345678.
        assert_eq!(bytes[0..4], [0x78, 0x56, 0x34, 0x12]);
    }
#[test]
fn test_sparsity_bitmap_creation() {
use crate::kernels::ternary::TernaryTensor;
        let shape = (4, 128);
        let k_words = 4;
let mut plus = vec![0u32; 4 * k_words];
for row in 0..4 {
            // Alternate dense and empty words so that each 64-bit chunk
            // (two u32 words) contains exactly one nonzero word.
            plus[row * k_words] = 0xFFFFFFFF;
            plus[row * k_words + 1] = 0x0;
            plus[row * k_words + 2] = 0xFFFFFFFF;
            plus[row * k_words + 3] = 0x0;
        }
let minus = vec![0u32; 4 * k_words];
let scales = vec![1.0f32; 4];
let tensor = TernaryTensor::new(plus, minus, scales, shape);
let bitmap_bytes = create_sparsity_bitmap_for_tensor(&tensor, 64);
        // Two chunks per row fit in a single u64 bitmap word.
        let bitmap_words_per_row = 1;
assert_eq!(bitmap_bytes.len(), 4 * bitmap_words_per_row * 8);
let bitmap: Vec<u64> = bitmap_bytes
.chunks_exact(8)
.map(|chunk| {
u64::from_le_bytes([
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
])
})
.collect();
for row in 0..4 {
let row_bitmap = bitmap[row];
assert_ne!(
row_bitmap & 0x1,
0,
"Chunk 0 should be active for row {}",
row
);
assert_ne!(
row_bitmap & 0x2,
0,
"Chunk 1 should be active for row {}",
row
);
}
}
#[test]
fn test_sparsity_bitmap_fully_sparse() {
use crate::kernels::ternary::TernaryTensor;
let shape = (2, 128);
let k_words = 4;
let plus = vec![0u32; 2 * k_words];
let minus = vec![0u32; 2 * k_words];
let scales = vec![1.0f32; 2];
let tensor = TernaryTensor::new(plus, minus, scales, shape);
let bitmap_bytes = create_sparsity_bitmap_for_tensor(&tensor, 64);
let bitmap: Vec<u64> = bitmap_bytes
.chunks_exact(8)
.map(|chunk| {
u64::from_le_bytes([
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
])
})
.collect();
for &word in &bitmap {
assert_eq!(
word, 0,
"Fully sparse tensor should have all chunks inactive"
);
}
}
#[test]
fn test_sparsity_bitmap_fully_dense() {
use crate::kernels::ternary::TernaryTensor;
let shape = (2, 128);
let k_words = 4;
let plus = vec![0xFFFFFFFFu32; 2 * k_words];
let minus = vec![0u32; 2 * k_words];
let scales = vec![1.0f32; 2];
let tensor = TernaryTensor::new(plus, minus, scales, shape);
let bitmap_bytes = create_sparsity_bitmap_for_tensor(&tensor, 64);
let bitmap: Vec<u64> = bitmap_bytes
.chunks_exact(8)
.map(|chunk| {
u64::from_le_bytes([
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
])
})
.collect();
        let num_chunks = 2;
        // All chunks active: the low `num_chunks` bits of each row are set.
        let expected = (1u64 << num_chunks) - 1;
        for (row, &word) in bitmap.iter().enumerate() {
            assert_eq!(
word, expected,
"Fully dense tensor should have all chunks active for row {}",
row
);
}
}
#[cfg(feature = "cuda")]
mod cuda_tests {
use super::*;
#[test]
fn test_roundtrip_conversion() {
if let Ok(device) = Device::cuda_if_available(0) {
if matches!(device, Device::Cuda(_)) {
let original = Tensor::randn(0.0f32, 1.0, (2, 4, 8, 64), &device).unwrap();
let (bytes, shape, _) = candle_to_cubecl_handle(&original).unwrap();
let recovered = cubecl_to_candle_tensor(&bytes, &shape, &device).unwrap();
assert_eq!(original.dims(), recovered.dims());
let orig_data: Vec<f32> = original.flatten_all().unwrap().to_vec1().unwrap();
let rec_data: Vec<f32> = recovered.flatten_all().unwrap().to_vec1().unwrap();
for (a, b) in orig_data.iter().zip(rec_data.iter()) {
assert!((a - b).abs() < 1e-6, "Values differ: {} vs {}", a, b);
}
}
}
}
#[test]
fn test_u32_tensor_roundtrip() {
if let Ok(device) = Device::cuda_if_available(0) {
if matches!(device, Device::Cuda(_)) {
let data: Vec<u32> = vec![
0x12345678, 0xABCDEF01, 0xDEADBEEF, 0xCAFEBABE, 0xFFFFFFFF, 0x00000000,
0xAAAAAAAA, 0x55555555,
];
                    let shape = (2, 4);
                    let original = Tensor::from_vec(data.clone(), shape, &device).unwrap();
let (bytes, ret_shape, dtype) = u32_tensor_to_cubecl_handle(&original).unwrap();
assert_eq!(ret_shape, vec![2, 4]);
assert_eq!(dtype, DType::U32);
assert_eq!(bytes.len(), 8 * 4);
let recovered = cubecl_to_u32_candle_tensor(&bytes, &[2, 4], &device).unwrap();
assert_eq!(recovered.dims(), &[2, 4]);
let rec_data: Vec<u32> = recovered.flatten_all().unwrap().to_vec1().unwrap();
assert_eq!(data, rec_data);
}
}
}
#[test]
fn test_u32_tensor_large() {
if let Ok(device) = Device::cuda_if_available(0) {
if matches!(device, Device::Cuda(_)) {
let rows = 512;
let k_words = 128;
let total = rows * k_words;
let data: Vec<u32> = (0..total)
.map(|i| if i % 2 == 0 { 0xAAAAAAAA } else { 0x55555555 })
.collect();
let original =
Tensor::from_vec(data.clone(), (rows, k_words), &device).unwrap();
let (bytes, shape, _) = u32_tensor_to_cubecl_handle(&original).unwrap();
let recovered = cubecl_to_u32_candle_tensor(&bytes, &shape, &device).unwrap();
assert_eq!(recovered.dims(), &[rows, k_words]);
let rec_data: Vec<u32> = recovered.flatten_all().unwrap().to_vec1().unwrap();
assert_eq!(data, rec_data);
}
}
}
#[test]
fn test_u32_tensor_wrong_dtype() {
if let Ok(device) = Device::cuda_if_available(0) {
if matches!(device, Device::Cuda(_)) {
let tensor = Tensor::zeros((4, 8), DType::F32, &device).unwrap();
let result = u32_tensor_to_cubecl_handle(&tensor);
assert!(result.is_err());
if let Err(e) = result {
assert!(e.to_string().contains("U32"));
}
}
}
}
}
}