use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{DType, Device, Result, Tensor, WithDType};
use float8::F8E4M3;
use half::{bf16, f16};
// Semantic version of the UQFF artifact format understood by this build.
const UQFF_VERSION_MAJOR: u32 = 0;
const UQFF_VERSION_MINOR: u32 = 2;
const UQFF_VERSION_PATCH: u32 = 0;
/// Format version packed into one u32 as `(major << 16) | (minor << 8) | patch`;
/// unpacked again by `version_is_compatible`.
pub(crate) const UQFF_VERSION: u32 =
    (UQFF_VERSION_MAJOR << (8 * 2)) | (UQFF_VERSION_MINOR << 8) | UQFF_VERSION_PATCH;
/// Byte offset of the quantization-type tag within a UQFF artifact — it sits
/// right after a leading u32 (presumably the packed version header; confirm
/// against the artifact writer, which is not visible in this file).
pub const UQFF_QUANT_TYPE_OFFSET: usize = std::mem::size_of::<u32>();
/// Check a packed `major.minor.patch` version (see [`UQFF_VERSION`]) read
/// from an artifact against the version this build implements.
///
/// # Errors
/// Fails when the major versions differ, or when the artifact's minor
/// version is newer than this build supports (older minors are accepted).
pub(crate) fn version_is_compatible(version: u32) -> Result<()> {
    // Unpack the u32 back into its major/minor/patch bytes.
    let (major, minor, patch) = (version >> (8 * 2), (version >> 8) & 0xFF, version & 0xFF);

    if major != UQFF_VERSION_MAJOR {
        candle_core::bail!("Major version of ISQ artifact file ({major}) does not match the implementation in this build ({UQFF_VERSION_MAJOR})");
    }
    if minor > UQFF_VERSION_MINOR {
        candle_core::bail!("Minor version of ISQ artifact file ({major}.{minor}.{patch}) is newer than this build supports ({UQFF_VERSION_MAJOR}.{UQFF_VERSION_MINOR}.{UQFF_VERSION_PATCH}). Please update mistral.rs.");
    }
    Ok(())
}
/// Append the 4-byte little-endian UQFF tag for `dtype` onto `buffer`.
///
/// The numeric tags are part of the on-disk UQFF format: they must stay in
/// lockstep with [`read_dtype`] and must never be renumbered.
///
/// # Panics
/// Panics on a dtype with no UQFF tag — attempting to serialize such a
/// dtype is a programming error, not a recoverable condition.
pub(crate) fn write_dtype(dtype: DType, buffer: &mut Vec<u8>) {
    let tag: u32 = match dtype {
        DType::U8 => 0,
        DType::U32 => 1,
        DType::I32 => 2,
        DType::I64 => 3,
        DType::F16 => 4,
        DType::BF16 => 5,
        DType::F32 => 6,
        DType::F64 => 7,
        DType::I16 => 8,
        DType::F8E4M3 => 9,
        DType::F6E2M3 => 10,
        DType::F6E3M2 => 11,
        DType::F4 => 12,
        DType::F8E8M0 => 13,
        other => panic!("Unsupported dtype for UQFF serialization: {other:?}"),
    };
    buffer.extend(&tag.to_le_bytes());
}
/// Read a 4-byte little-endian dtype tag from `buffer` and map it back to a
/// [`DType`]. Inverse of [`write_dtype`].
///
/// # Errors
/// Unknown tags yield an error rather than a panic, since they come from
/// (possibly corrupt or newer-format) artifact files.
pub(crate) fn read_dtype<R: std::io::Read>(buffer: &mut R) -> Result<DType> {
    let dtype = buffer.read_u32::<LittleEndian>()?;
    Ok(match dtype {
        0 => DType::U8,
        1 => DType::U32,
        2 => DType::I32,
        3 => DType::I64,
        4 => DType::F16,
        5 => DType::BF16,
        6 => DType::F32,
        7 => DType::F64,
        8 => DType::I16,
        9 => DType::F8E4M3,
        10 => DType::F6E2M3,
        11 => DType::F6E3M2,
        12 => DType::F4,
        13 => DType::F8E8M0,
        _ => candle_core::bail!("unknown dtype for quantized tensor {dtype}"),
    })
}
pub(crate) fn serialize_tensor(buffer: &mut Vec<u8>, tensor: &Tensor) -> Result<()> {
let b_shape = tensor.dims();
let tensor = tensor.flatten_all()?;
let bias = match tensor.dtype() {
DType::U8 => data_to_bytes::<u8>(tensor.to_vec1()?),
DType::U32 => data_to_bytes::<u32>(tensor.to_vec1()?),
DType::I16 => data_to_bytes::<i16>(tensor.to_vec1()?),
DType::I32 => data_to_bytes::<i32>(tensor.to_vec1()?),
DType::I64 => data_to_bytes::<i64>(tensor.to_vec1()?),
DType::F16 => data_to_bytes::<half::f16>(tensor.to_vec1()?),
DType::BF16 => data_to_bytes::<half::bf16>(tensor.to_vec1()?),
DType::F32 => data_to_bytes::<f32>(tensor.to_vec1()?),
DType::F64 => data_to_bytes::<f64>(tensor.to_vec1()?),
DType::F8E4M3 => data_to_bytes::<F8E4M3>(tensor.to_vec1()?),
DType::F4 | DType::F6E3M2 | DType::F6E2M3 | DType::F8E8M0 => {
candle_core::bail!("f4/f6e3m2/f6e2m3/f8e8m0 tensors cannot be serialized.")
}
other => {
candle_core::bail!("Unsupported dtype for UQFF tensor serialization: {other:?}")
}
};
let data_len = bias.len();
if data_len > u32::MAX as usize {
candle_core::bail!(
"Tensor data too large for UQFF format: {} bytes exceeds u32::MAX",
data_len
);
}
buffer.extend(&(data_len as u32).to_le_bytes());
write_dtype(tensor.dtype(), buffer);
let shape_len = b_shape.len();
if shape_len > u32::MAX as usize {
candle_core::bail!(
"Tensor has too many dimensions for UQFF format: {} exceeds u32::MAX",
shape_len
);
}
buffer.extend((shape_len as u32).to_le_bytes());
for dim in b_shape {
if *dim > u32::MAX as usize {
candle_core::bail!(
"Tensor dimension too large for UQFF format: {} exceeds u32::MAX",
dim
);
}
buffer.extend((*dim as u32).to_le_bytes());
}
buffer.extend(bias);
Ok(())
}
pub(crate) fn deserialize_tensor<R: std::io::Read>(
buffer: &mut R,
device: &Device,
) -> Result<Tensor> {
let data_len = buffer.read_u32::<LittleEndian>()? as usize;
let dtype = read_dtype(buffer)?;
let n_dims = buffer.read_u32::<LittleEndian>()? as usize;
let mut dims = Vec::with_capacity(n_dims);
for _ in 0..n_dims {
dims.push(buffer.read_u32::<LittleEndian>()? as usize)
}
let mut tensor_data = vec![0; data_len];
buffer.read_exact(&mut tensor_data)?;
match dtype {
DType::F16 => bytes_to_data::<f16>(&tensor_data, &dims, device),
DType::BF16 => bytes_to_data::<bf16>(&tensor_data, &dims, device),
DType::F32 => bytes_to_data::<f32>(&tensor_data, &dims, device),
DType::F64 => bytes_to_data::<f64>(&tensor_data, &dims, device),
DType::I32 => bytes_to_data::<i32>(&tensor_data, &dims, device),
DType::I64 => bytes_to_data::<i64>(&tensor_data, &dims, device),
DType::I16 => bytes_to_data::<i16>(&tensor_data, &dims, device),
DType::U32 => bytes_to_data::<u32>(&tensor_data, &dims, device),
DType::U8 => bytes_to_data::<u8>(&tensor_data, &dims, device),
DType::F8E4M3 => bytes_to_data::<F8E4M3>(&tensor_data, &dims, device),
DType::F4 | DType::F6E3M2 | DType::F6E2M3 | DType::F8E8M0 => {
candle_core::bail!("f4/f6e3m2/f6e2m3/f8e8m0 tensors cannot be deserialized.")
}
other => {
candle_core::bail!("Unsupported dtype for UQFF tensor deserialization: {other:?}")
}
}
}
/// Skip over one UQFF-serialized tensor record (see `serialize_tensor` for
/// the layout) without materializing it, leaving `buffer` positioned at the
/// first byte after the record.
///
/// # Errors
/// Fails on I/O errors or an unknown dtype tag.
pub(crate) fn fake_deserialize_tensor<R: std::io::Read + std::io::Seek>(
    buffer: &mut R,
) -> Result<()> {
    let data_len = buffer.read_u32::<LittleEndian>()? as usize;
    let _dtype = read_dtype(buffer)?;
    let n_dims = buffer.read_u32::<LittleEndian>()? as usize;
    // Seek past the shape (n_dims little-endian u32 values) and the raw
    // payload in one step, instead of reading the dims into a Vec that was
    // never used. Both counts were read from u32s, so the sum fits in i64.
    buffer.seek_relative((n_dims * std::mem::size_of::<u32>() + data_len) as i64)?;
    Ok(())
}
/// Return the raw bytes of `vs`'s elements as a `Vec<u8>`.
///
/// The previous implementation rebuilt the `Vec<T>` allocation as a
/// `Vec<u8>` via `Vec::from_raw_parts`, which deallocates the buffer with
/// alignment 1 instead of `align_of::<T>()` — a violation of the
/// `from_raw_parts` safety contract (UB for any `T` with alignment > 1,
/// e.g. f32/i64). We instead view the elements as a byte slice and copy,
/// which is sound; `serialize_tensor` copies the result into its output
/// buffer anyway.
fn data_to_bytes<T: WithDType>(vs: Vec<T>) -> Vec<u8> {
    let size_in_bytes = T::DTYPE.size_in_bytes();
    let length = vs.len() * size_in_bytes;
    // SAFETY: the supported dtypes are plain-old-data values with no padding
    // (size_in_bytes bytes each, per T::DTYPE), so reading `length` bytes
    // starting at the element buffer is in bounds and fully initialized.
    // `vs` stays alive (and its allocation valid) for the whole borrow.
    let bytes = unsafe { std::slice::from_raw_parts(vs.as_ptr() as *const u8, length) };
    bytes.to_vec()
}
/// Reinterpret `data` (raw bytes, as produced by `data_to_bytes`) as
/// elements of `T` and build a tensor with `shape` on `device`.
///
/// Fast path: if the byte buffer happens to be sufficiently aligned for
/// `T`, the slice is reinterpreted in place with no copy; otherwise the
/// bytes are copied into a freshly allocated (hence aligned) `Vec<T>`.
fn bytes_to_data<T: WithDType>(
    data: &[u8],
    shape: &[usize],
    device: &candle_core::Device,
) -> Result<Tensor> {
    let size_in_bytes = T::DTYPE.size_in_bytes();
    let elem_count = data.len() / size_in_bytes;
    // Only whole elements are used; a trailing partial element (which a
    // well-formed UQFF payload never has) is ignored rather than read.
    let byte_len = elem_count * size_in_bytes;
    if (data.as_ptr() as usize).is_multiple_of(size_in_bytes) {
        // SAFETY: the pointer's alignment was checked above (for the dtypes
        // used here the element size is >= the element alignment), and
        // `elem_count` elements (`byte_len` <= data.len() bytes) lie fully
        // within the initialized `data` slice.
        let data: &[T] =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
        Tensor::from_slice(data, shape, device)
    } else {
        let mut c: Vec<T> = Vec::with_capacity(elem_count);
        // SAFETY: `c` has capacity for exactly `elem_count` elements, i.e.
        // `byte_len` bytes. Copying `byte_len` bytes (not `data.len()`, which
        // could exceed the capacity when it is not an exact multiple of the
        // element size) keeps the write in bounds. Source and destination
        // cannot overlap, and `set_len` runs only after every element has
        // been fully initialized by the copy.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, byte_len);
            c.set_len(elem_count)
        }
        Tensor::from_slice(&c, shape, device)
    }
}
#[cfg(test)]
mod tests {
    /// Guards the UQFF on-disk dtype encoding: if candle adds, removes, or
    /// resizes a `DType` variant, this fails so the `write_dtype`/`read_dtype`
    /// match arms get reviewed.
    #[test]
    fn dtype_variant_count_unchanged() {
        use candle_core::DType;

        assert_eq!(
            std::mem::size_of::<DType>(),
            1,
            "DType repr size changed, check if the discriminant size is the same"
        );

        const EXPECTED_VARIANTS: usize = 14;
        let known = [
            DType::U8,
            DType::U32,
            DType::I16,
            DType::I32,
            DType::I64,
            DType::BF16,
            DType::F16,
            DType::F32,
            DType::F64,
            DType::F8E4M3,
            DType::F6E2M3,
            DType::F6E3M2,
            DType::F4,
            DType::F8E8M0,
        ];
        assert_eq!(
            known.len(),
            EXPECTED_VARIANTS,
            "Update this list and the UQFF match arms when DType variants change"
        );
    }
}