use serde::{Deserialize, Serialize};
use serde::{Deserializer, Serializer};
use std::sync::Arc;
use crate::buffer::{BufferHandle, CpuBuffer};
use crate::dtype::DType;
use crate::error::{Result, SapientError};
use crate::shape::Shape;
/// An n-dimensional view over a (possibly shared) [`BufferHandle`].
///
/// The view is described by a shape, a dtype, per-dimension strides and a
/// byte offset into the backing buffer.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Logical dimensions of the view.
    shape: Shape,
    /// Element type of the data held in `buffer`.
    dtype: DType,
    /// Per-dimension strides, counted in elements (not bytes).
    strides: Vec<usize>,
    /// Backing storage; views such as `reshape` and `t` clone this handle.
    buffer: BufferHandle,
    /// Byte offset of the view's first element inside `buffer`.
    offset: usize,
}
impl Tensor {
/// Creates a zero-filled tensor with the given shape and dtype.
///
/// # Errors
/// Propagates shape-validation failures and any error reported by
/// `CpuBuffer::zeros`.
pub fn zeros(shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
    let shape: Shape = shape.into();
    shape.validate()?;
    let strides = shape.strides();
    let storage = CpuBuffer::zeros(shape.numel(), dtype)?;
    Ok(Self {
        shape,
        dtype,
        strides,
        buffer: BufferHandle::new(storage),
        offset: 0,
    })
}
/// Builds an F32 tensor that takes ownership of `data`.
///
/// # Errors
/// Returns [`SapientError::ShapeMismatch`] when `data.len()` differs from the
/// element count of `shape`. Shape-validation and buffer-construction errors
/// are propagated unchanged.
pub fn from_f32_vec(data: Vec<f32>, shape: impl Into<Shape>) -> Result<Self> {
    let shape: Shape = shape.into();
    shape.validate()?;

    let provided = data.len();
    if provided != shape.numel() {
        let expected = shape.dims().to_vec();
        return Err(SapientError::ShapeMismatch {
            expected,
            got: vec![provided],
        });
    }

    let strides = shape.strides();
    let storage = CpuBuffer::from_f32_vec(data)?;
    Ok(Self {
        shape,
        dtype: DType::F32,
        strides,
        buffer: BufferHandle::new(storage),
        offset: 0,
    })
}
/// Builds an F32 tensor from a borrowed slice of values.
///
/// # Errors
/// Returns [`SapientError::ShapeMismatch`] when the slice length does not
/// equal the element count of `shape`. Shape-validation and
/// buffer-construction errors are propagated unchanged.
pub fn from_f32(data: &[f32], shape: impl Into<Shape>) -> Result<Self> {
    let shape: Shape = shape.into();
    shape.validate()?;

    let wanted = shape.numel();
    if data.len() != wanted {
        return Err(SapientError::ShapeMismatch {
            expected: shape.dims().to_vec(),
            got: vec![data.len()],
        });
    }

    let strides = shape.strides();
    let storage = CpuBuffer::from_f32_slice(data)?;
    let buffer = BufferHandle::new(storage);
    Ok(Self {
        shape,
        dtype: DType::F32,
        strides,
        buffer,
        offset: 0,
    })
}
/// Builds a BF16 tensor from raw bytes, two bytes per element.
///
/// # Errors
/// Returns [`SapientError::ShapeMismatch`] when `data.len()` is not exactly
/// `2 * numel`; the `got` field reports `data.len() / 2`. Shape-validation
/// and buffer-construction errors are propagated unchanged.
pub fn from_bf16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
    const BYTES_PER_ELEMENT: usize = 2;

    let shape: Shape = shape.into();
    shape.validate()?;

    if data.len() != shape.numel() * BYTES_PER_ELEMENT {
        return Err(SapientError::ShapeMismatch {
            expected: shape.dims().to_vec(),
            got: vec![data.len() / BYTES_PER_ELEMENT],
        });
    }

    let strides = shape.strides();
    let storage = CpuBuffer::from_bytes_slice(data)?;
    Ok(Self {
        shape,
        dtype: DType::BF16,
        strides,
        buffer: BufferHandle::new(storage),
        offset: 0,
    })
}
/// Builds an F16 tensor from raw bytes, two bytes per element.
///
/// # Errors
/// Returns [`SapientError::ShapeMismatch`] when `data.len()` is not exactly
/// `2 * numel`; the `got` field reports `data.len() / 2`. Shape-validation
/// and buffer-construction errors are propagated unchanged.
pub fn from_f16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
    const BYTES_PER_ELEMENT: usize = 2;

    let shape: Shape = shape.into();
    shape.validate()?;

    if data.len() != shape.numel() * BYTES_PER_ELEMENT {
        return Err(SapientError::ShapeMismatch {
            expected: shape.dims().to_vec(),
            got: vec![data.len() / BYTES_PER_ELEMENT],
        });
    }

    let strides = shape.strides();
    let storage = CpuBuffer::from_bytes_slice(data)?;
    Ok(Self {
        shape,
        dtype: DType::F16,
        strides,
        buffer: BufferHandle::new(storage),
        offset: 0,
    })
}
/// Builds a quantized tensor from raw block bytes.
///
/// # Errors
/// - [`SapientError::TypeMismatch`] when `dtype.is_quantized()` is false.
/// - [`SapientError::ShapeMismatch`] when `data.len()` differs from
///   `dtype.byte_count(numel)`; both fields are reported in bytes.
/// - Shape-validation and buffer-construction errors are propagated.
pub fn from_quant_bytes(data: &[u8], shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
    // The dtype is checked before the shape is even converted.
    if !dtype.is_quantized() {
        return Err(SapientError::TypeMismatch {
            expected: "a quantized dtype (Q4_0, Q8_0, Q4_K, Q5_K, Q6_K)".into(),
            got: dtype.to_string(),
        });
    }

    let shape: Shape = shape.into();
    shape.validate()?;

    let wanted_bytes = dtype.byte_count(shape.numel());
    let provided_bytes = data.len();
    if provided_bytes != wanted_bytes {
        return Err(SapientError::ShapeMismatch {
            expected: vec![wanted_bytes],
            got: vec![provided_bytes],
        });
    }

    let strides = shape.strides();
    let storage = CpuBuffer::from_bytes_slice(data)?;
    Ok(Self {
        shape,
        dtype,
        strides,
        buffer: BufferHandle::new(storage),
        offset: 0,
    })
}
/// Wraps a single `f32` in a tensor whose shape is `Shape::scalar()`.
///
/// # Errors
/// Propagates any error returned by [`Tensor::from_f32`].
pub fn scalar_f32(v: f32) -> Result<Self> {
    let single = [v];
    Self::from_f32(&single, Shape::scalar())
}
/// Creates a tensor view over an existing buffer, starting at byte `offset`.
///
/// The view uses the default strides of `shape`.
///
/// # Errors
/// Returns [`SapientError::BufferSizeMismatch`] when the buffer is shorter
/// than `offset + dtype.byte_count(numel)` bytes, including the case where
/// that sum does not fit in `usize`. Shape-validation errors are propagated.
pub fn from_buffer(
    shape: impl Into<Shape>,
    dtype: DType,
    buffer: BufferHandle,
    offset: usize,
) -> Result<Self> {
    let shape: Shape = shape.into();
    shape.validate()?;

    let required = dtype.byte_count(shape.numel());
    let available = buffer.len();
    // `offset + required` is computed with overflow detection: a wrapped sum
    // in a release build could otherwise pass the size check for a view that
    // lies far outside the buffer.
    let fits = match offset.checked_add(required) {
        Some(end) => end <= available,
        None => false,
    };
    if !fits {
        return Err(SapientError::BufferSizeMismatch {
            expected: offset.saturating_add(required),
            got: available,
        });
    }

    let strides = shape.strides();
    Ok(Self {
        shape,
        dtype,
        strides,
        buffer,
        offset,
    })
}
pub fn shape(&self) -> &Shape {
&self.shape
}
/// Returns the tensor's element type.
pub fn dtype(&self) -> DType {
    self.dtype
}
/// Returns the number of dimensions, as reported by the shape.
pub fn ndim(&self) -> usize {
    let shape = &self.shape;
    shape.ndim()
}
/// Returns the number of elements, as reported by the shape.
pub fn numel(&self) -> usize {
    let shape = &self.shape;
    shape.numel()
}
/// Returns the per-dimension strides of this view.
pub fn strides(&self) -> &[usize] {
    self.strides.as_slice()
}
pub fn buffer(&self) -> &BufferHandle {
&self.buffer
}
/// Returns the byte offset of this view inside the backing buffer.
pub fn offset(&self) -> usize {
    self.offset
}
/// Returns `true` when the shape reports itself as scalar or the tensor
/// holds exactly one element.
pub fn is_scalar(&self) -> bool {
    if self.shape.is_scalar() {
        return true;
    }
    self.numel() == 1
}
/// Returns `true` when the strides equal the shape's default strides and
/// the view starts at byte offset zero.
pub fn is_contiguous(&self) -> bool {
    let default_strides = self.shape.strides();
    let starts_at_origin = self.offset == 0;
    self.strides == default_strides && starts_at_origin
}
/// Returns the raw bytes of this view, starting at the view's byte offset.
///
/// For quantized dtypes the slice is exactly `dtype.byte_count(numel)` bytes
/// long. For every other dtype it runs to the end of the backing buffer, so
/// it can be longer than the view's own data.
///
/// # Panics
/// Panics if the computed range lies outside the backing buffer.
pub fn as_bytes(&self) -> &[u8] {
    let raw = self.buffer.as_bytes();
    let start = self.offset;
    match self.dtype.is_quantized() {
        true => {
            let stop = start + self.dtype.byte_count(self.numel());
            &raw[start..stop]
        }
        false => &raw[start..],
    }
}
/// Returns the raw block bytes of a quantized tensor.
///
/// # Panics
/// Panics when the tensor's dtype is not quantized.
pub fn as_quant_blocks(&self) -> &[u8] {
    let dtype = self.dtype;
    if !dtype.is_quantized() {
        panic!(
            "as_quant_blocks() called on non-quantized tensor (dtype = {})",
            dtype
        );
    }
    self.as_bytes()
}
/// Reinterprets the view's bytes (see `as_bytes`) as a slice of `f32`.
///
/// The slice covers everything `as_bytes` returns, so for F32 tensors it can
/// extend past `numel` elements when the backing buffer is larger than the
/// view.
///
/// # Panics
/// - when the dtype is not `F32`;
/// - when the byte length is not a multiple of 4;
/// - when the first byte is not aligned for `f32` (for example a view
///   created with a byte offset that is not a multiple of 4).
pub fn as_f32_slice(&self) -> &[f32] {
    assert_eq!(
        self.dtype,
        DType::F32,
        "Tensor dtype is not F32 — call to_f32_vec() instead"
    );
    let bytes = self.as_bytes();
    assert_eq!(bytes.len() % 4, 0);
    if bytes.is_empty() {
        // An empty byte slice may carry a dangling, byte-aligned pointer;
        // avoid reinterpreting it at all.
        return &[];
    }
    // Reinterpreting misaligned bytes as `f32` is undefined behaviour, so the
    // alignment is verified instead of assumed.
    assert_eq!(
        bytes.as_ptr() as usize % std::mem::align_of::<f32>(),
        0,
        "Tensor data is not aligned for f32 access (byte offset = {})",
        self.offset
    );
    // SAFETY: the pointer is non-null and was just checked to be aligned for
    // `f32`; `bytes.len()` is a multiple of 4, so `len / 4` floats lie inside
    // the borrowed byte slice; every bit pattern is a valid `f32`; the result
    // borrows `self`, so the bytes outlive it.
    unsafe { std::slice::from_raw_parts(bytes.as_ptr().cast::<f32>(), bytes.len() / 4) }
}
/// Returns the tensor's elements as `f32`, in logical (row-major) order.
///
/// When `is_contiguous()` holds, the first `numel` converted values are
/// returned directly. Otherwise each logical index is mapped through the
/// view's strides into the converted data; a source index that falls
/// outside that data yields `0.0`.
///
/// # Panics
/// Panics under the same conditions as `as_f32_slice` / `to_f32_vec`.
pub fn to_contiguous_f32_vec(&self) -> Vec<f32> {
    let numel = self.numel();

    if self.is_contiguous() {
        return match self.dtype {
            DType::F32 => self.as_f32_slice()[..numel].to_vec(),
            _ => {
                // Trim in place rather than copying the converted vector.
                let mut values = self.to_f32_vec();
                values.truncate(numel);
                values
            }
        };
    }

    // Borrow F32 data instead of cloning the whole buffer tail; only other
    // dtypes need an owned, converted copy.
    let raw: std::borrow::Cow<'_, [f32]> = match self.dtype {
        DType::F32 => std::borrow::Cow::Borrowed(self.as_f32_slice()),
        _ => std::borrow::Cow::Owned(self.to_f32_vec()),
    };
    let dims = self.shape.dims();
    let strides = &self.strides;

    (0..numel)
        .map(|flat| {
            // Decompose the flat row-major index from the innermost dimension
            // outwards, accumulating the strided source index.
            let mut rem = flat;
            let mut src = 0usize;
            for d in (0..dims.len()).rev() {
                src += (rem % dims[d]) * strides[d];
                rem /= dims[d];
            }
            raw.get(src).copied().unwrap_or(0.0)
        })
        .collect()
}
/// Returns the data as `f32`: borrowed for F32 tensors, converted into an
/// owned vector for every other dtype.
pub fn to_f32_cow(&self) -> std::borrow::Cow<'_, [f32]> {
    use std::borrow::Cow;

    match self.dtype {
        DType::F32 => Cow::Borrowed(self.as_f32_slice()),
        _ => Cow::Owned(self.to_f32_vec()),
    }
}
/// Converts the view's bytes (see `as_bytes`) into a new `Vec<f32>`.
///
/// - `F32`: copies the slice returned by `as_f32_slice`.
/// - `BF16` / `F16`: widens every little-endian 2-byte element.
/// - `Q4_0`, `Q8_0`, `Q4_K`, `Q5_K`, `Q6_K`: decodes block by block into a
///   vector of `numel` values.
/// - Any other dtype falls through to `as_f32_slice`, whose dtype assertion
///   then panics.
///
/// Strides are not consulted; use `to_contiguous_f32_vec` for strided views.
///
/// # Panics
/// Panics for unsupported dtypes (see above) and when a quantized block
/// decodes past `numel` output values.
pub fn to_f32_vec(&self) -> Vec<f32> {
    use crate::dtype::{
        K_QUANT_BLOCK_SIZE, Q4_0_BLOCK_BYTES, Q4_K_BLOCK_BYTES, Q5_K_BLOCK_BYTES,
        Q6_K_BLOCK_BYTES, Q8_0_BLOCK_BYTES, QUANT_BLOCK_SIZE,
    };
    match self.dtype {
        DType::F32 => self.as_f32_slice().to_vec(),
        DType::BF16 => self
            .as_bytes()
            .chunks_exact(2)
            .map(|c| f32::from(half::bf16::from_le_bytes([c[0], c[1]])))
            .collect(),
        DType::F16 => self
            .as_bytes()
            .chunks_exact(2)
            .map(|c| half::f16::from_le_bytes([c[0], c[1]]).to_f32())
            .collect(),
        DType::Q4_0 => {
            // Block: f16 scale `d`, then QUANT_BLOCK_SIZE / 2 bytes of packed
            // nibbles. Low nibbles fill the first half of the block's output,
            // high nibbles the second half; both are biased by 8.
            let numel = self.numel();
            let bytes = self.as_bytes();
            let mut out = vec![0.0f32; numel];
            for (b, block) in bytes.chunks_exact(Q4_0_BLOCK_BYTES).enumerate() {
                let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
                let base = b * QUANT_BLOCK_SIZE;
                for j in 0..QUANT_BLOCK_SIZE / 2 {
                    let byte = block[2 + j];
                    let lo = (byte & 0x0f) as i32 - 8;
                    let hi = (byte >> 4) as i32 - 8;
                    out[base + j] = lo as f32 * d;
                    out[base + j + QUANT_BLOCK_SIZE / 2] = hi as f32 * d;
                }
            }
            out
        }
        DType::Q8_0 => {
            // Block: f16 scale `d`, then QUANT_BLOCK_SIZE signed bytes.
            let numel = self.numel();
            let bytes = self.as_bytes();
            let mut out = vec![0.0f32; numel];
            for (b, block) in bytes.chunks_exact(Q8_0_BLOCK_BYTES).enumerate() {
                let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
                let base = b * QUANT_BLOCK_SIZE;
                for j in 0..QUANT_BLOCK_SIZE {
                    out[base + j] = block[2 + j] as i8 as f32 * d;
                }
            }
            out
        }
        DType::Q4_K => {
            // Block: f16 `d`, f16 `dmin`, 12 bytes of packed 6-bit
            // scales/mins, then packed nibbles. Each group of 32 nibble bytes
            // yields 64 values (low nibbles first, then high nibbles).
            let numel = self.numel();
            let bytes = self.as_bytes();
            let mut out = vec![0.0f32; numel];
            let mut out_idx = 0usize;
            for block in bytes.chunks_exact(Q4_K_BLOCK_BYTES) {
                let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
                let dmin = half::f16::from_le_bytes([block[2], block[3]]).to_f32();
                let scales = &block[4..16];
                let qs = &block[16..Q4_K_BLOCK_BYTES];
                let mut q_off = 0usize;
                let mut is = 0usize;
                for _ in 0..(K_QUANT_BLOCK_SIZE / 64) {
                    let (sc1, m1) = Self::get_scale_min_k4(is, scales);
                    let d1 = d * sc1 as f32;
                    let m1v = dmin * m1 as f32;
                    let (sc2, m2) = Self::get_scale_min_k4(is + 1, scales);
                    let d2 = d * sc2 as f32;
                    let m2v = dmin * m2 as f32;
                    for l in 0..32 {
                        let packed = qs[q_off + l];
                        out[out_idx + l] = d1 * (packed & 0x0F) as f32 - m1v;
                        out[out_idx + l + 32] = d2 * (packed >> 4) as f32 - m2v;
                    }
                    out_idx += 64;
                    q_off += 32;
                    is += 2;
                }
            }
            out
        }
        DType::Q5_K => {
            // Block: f16 `d`, f16 `dmin`, 12 bytes of packed scales/mins,
            // 32 bytes of high bits (`qh`), then packed low nibbles (`ql`).
            let numel = self.numel();
            let bytes = self.as_bytes();
            let mut out = vec![0.0f32; numel];
            let mut out_idx = 0usize;
            for block in bytes.chunks_exact(Q5_K_BLOCK_BYTES) {
                let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
                let dmin = half::f16::from_le_bytes([block[2], block[3]]).to_f32();
                let scales = &block[4..16];
                let qh = &block[16..48];
                let ql = &block[48..Q5_K_BLOCK_BYTES];
                let mut ql_off = 0usize;
                let mut is = 0usize;
                // Bit masks selecting this group's two high bits inside each
                // `qh` byte; they advance by two bit positions per group.
                let mut u1: u8 = 1;
                let mut u2: u8 = 2;
                for _ in 0..(K_QUANT_BLOCK_SIZE / 64) {
                    let (sc1, m1) = Self::get_scale_min_k4(is, scales);
                    let d1 = d * sc1 as f32;
                    let m1v = dmin * m1 as f32;
                    let (sc2, m2) = Self::get_scale_min_k4(is + 1, scales);
                    let d2 = d * sc2 as f32;
                    let m2v = dmin * m2 as f32;
                    for l in 0..32usize {
                        // The fifth bit of value `l` lives in `qh[l]` (one
                        // byte per position, one bit pair per group), as in
                        // ggml's `dequantize_row_q5_K`; the same byte must
                        // not be reused for every position.
                        let high_bits = qh[l];
                        let packed = ql[ql_off + l];
                        let hi1 = if high_bits & u1 != 0 { 16.0f32 } else { 0.0 };
                        let hi2 = if high_bits & u2 != 0 { 16.0f32 } else { 0.0 };
                        out[out_idx + l] = d1 * ((packed & 0x0F) as f32 + hi1) - m1v;
                        out[out_idx + l + 32] = d2 * ((packed >> 4) as f32 + hi2) - m2v;
                    }
                    out_idx += 64;
                    ql_off += 32;
                    is += 2;
                    u1 <<= 2;
                    u2 <<= 2;
                }
            }
            out
        }
        DType::Q6_K => {
            // Block: 128 bytes of low nibbles (`ql`), 64 bytes of 2-bit high
            // parts (`qh`), 16 signed scales (`sc`), then the f16 scale `d`.
            let numel = self.numel();
            let bytes = self.as_bytes();
            let mut out = vec![0.0f32; numel];
            let mut out_idx = 0usize;
            for block in bytes.chunks_exact(Q6_K_BLOCK_BYTES) {
                let ql = &block[0..128];
                let qh = &block[128..192];
                let sc = &block[192..208];
                let d = half::f16::from_le_bytes([block[208], block[209]]).to_f32();
                let mut ql_off = 0usize;
                let mut qh_off = 0usize;
                let mut sc_off = 0usize;
                for _ in 0..(K_QUANT_BLOCK_SIZE / 128) {
                    for l in 0..32usize {
                        // One scale covers 16 consecutive outputs, so each
                        // 32-wide run uses two scales (`l / 16`) and the four
                        // runs are spaced two scales apart, as in ggml's
                        // `dequantize_row_q6_K`.
                        let is = sc_off + l / 16;
                        let high = qh[qh_off + l];
                        let lo_a = ql[ql_off + l];
                        let lo_b = ql[ql_off + l + 32];
                        let q1 = (((lo_a & 0x0F) | ((high & 3) << 4)) as i32 - 32) as f32;
                        let q2 =
                            (((lo_b & 0x0F) | (((high >> 2) & 3) << 4)) as i32 - 32) as f32;
                        let q3 = (((lo_a >> 4) | (((high >> 4) & 3) << 4)) as i32 - 32) as f32;
                        let q4 = (((lo_b >> 4) | (((high >> 6) & 3) << 4)) as i32 - 32) as f32;
                        out[out_idx + l] = d * sc[is] as i8 as f32 * q1;
                        out[out_idx + l + 32] = d * sc[is + 2] as i8 as f32 * q2;
                        out[out_idx + l + 64] = d * sc[is + 4] as i8 as f32 * q3;
                        out[out_idx + l + 96] = d * sc[is + 6] as i8 as f32 * q4;
                    }
                    out_idx += 128;
                    ql_off += 64;
                    qh_off += 32;
                    // Each 128-value half consumes 8 of the block's 16 scales.
                    sc_off += 8;
                }
            }
            out
        }
        // Remaining dtypes have no conversion here; `as_f32_slice` asserts
        // the dtype is F32 and therefore panics for them.
        _ => self.as_f32_slice().to_vec(),
    }
}
/// Extracts the `j`-th (scale, min) pair from a packed scale table.
///
/// For `j < 4` both values are the low 6 bits of `scales[j]` and
/// `scales[j + 4]`. For larger `j` each value combines a nibble of
/// `scales[j + 4]` (low 4 bits) with the top two bits of `scales[j - 4]`
/// (scale) or `scales[j]` (min).
///
/// # Panics
/// Panics if any accessed index is outside `scales`.
#[inline]
fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
    const LOW_SIX_BITS: u8 = 63;

    if j >= 4 {
        let packed = scales[j + 4];
        let scale = (packed & 0x0F) | ((scales[j - 4] >> 6) << 4);
        let min = (packed >> 4) | ((scales[j] >> 6) << 4);
        return (scale, min);
    }

    let scale = scales[j] & LOW_SIX_BITS;
    let min = scales[j + 4] & LOW_SIX_BITS;
    (scale, min)
}
/// Returns an F32 version of this tensor.
///
/// F32 tensors are returned as a clone; any other dtype is converted with
/// `to_f32_vec` and wrapped in a new tensor with the same shape.
///
/// # Errors
/// Propagates any error from [`Tensor::from_f32`].
pub fn to_f32_tensor(&self) -> Result<Tensor> {
    if self.dtype == DType::F32 {
        return Ok(self.clone());
    }
    let converted = self.to_f32_vec();
    Tensor::from_f32(&converted, self.shape.clone())
}
pub fn as_bytes_mut(&mut self) -> Result<&mut [u8]> {
let offset = self.offset;
let end = offset + self.dtype.byte_count(self.numel());
let buf = Arc::get_mut(&mut self.buffer.0)
.ok_or_else(|| SapientError::internal("Cannot mutate shared tensor buffer"))?;
let bytes = buf.as_bytes_mut();
Ok(&mut bytes[offset..end])
}
/// Returns the buffer contents from the view's byte offset onwards as a
/// mutable `f32` slice.
///
/// # Errors
/// Returns an internal error when
/// - the dtype is not `F32`,
/// - the backing buffer is shared (`Arc::get_mut` fails),
/// - the remaining byte length is not a multiple of 4, or
/// - the first byte is not aligned for `f32`.
///
/// # Panics
/// Panics if the byte offset lies beyond the end of the backing buffer.
pub fn as_f32_slice_mut(&mut self) -> Result<&mut [f32]> {
    if self.dtype != DType::F32 {
        return Err(SapientError::internal("Tensor dtype is not F32"));
    }
    let offset = self.offset;
    let buf = Arc::get_mut(&mut self.buffer.0)
        .ok_or_else(|| SapientError::internal("Cannot mutate shared tensor buffer"))?;
    let bytes = &mut buf.as_bytes_mut()[offset..];
    if bytes.len() % 4 != 0 {
        return Err(SapientError::internal("Buffer length not a multiple of 4"));
    }
    if bytes.is_empty() {
        // An empty byte slice may carry a dangling, byte-aligned pointer;
        // avoid reinterpreting it at all.
        return Ok(&mut []);
    }
    // Reinterpreting misaligned bytes as `f32` is undefined behaviour, so the
    // alignment is verified instead of assumed.
    if bytes.as_ptr() as usize % std::mem::align_of::<f32>() != 0 {
        return Err(SapientError::internal(
            "Tensor data is not aligned for f32 access",
        ));
    }
    let len = bytes.len() / 4;
    // SAFETY: the pointer is non-null and was just checked to be aligned for
    // `f32`; `len * 4` bytes lie inside the exclusively borrowed byte slice;
    // every bit pattern is a valid `f32`; the result inherits the exclusive
    // borrow of `self`, so no aliasing access can exist while it is alive.
    Ok(unsafe { std::slice::from_raw_parts_mut(bytes.as_mut_ptr().cast::<f32>(), len) })
}
pub fn reshape(&self, new_shape: impl Into<Shape>) -> Result<Tensor> {
let new_shape = self.shape.reshape(new_shape.into().dims().to_vec())?;
let strides = new_shape.strides();
Ok(Tensor {
shape: new_shape,
dtype: self.dtype,
strides,
buffer: self.buffer.clone(),
offset: self.offset,
})
}
/// Returns a transposed view of a 2-D tensor.
///
/// The two dimensions and their strides are exchanged; the backing buffer
/// handle is cloned and the byte offset is kept.
///
/// # Errors
/// Returns an internal error when the tensor is not 2-D.
pub fn t(&self) -> Result<Tensor> {
    if self.ndim() != 2 {
        return Err(SapientError::internal("t() requires a 2-D tensor"));
    }

    // With exactly two entries, reversing is the same as swapping them.
    let mut swapped_dims = self.shape.dims().to_vec();
    swapped_dims.reverse();
    let mut swapped_strides = self.strides.clone();
    swapped_strides.reverse();

    Ok(Tensor {
        shape: Shape(swapped_dims),
        dtype: self.dtype,
        strides: swapped_strides,
        buffer: self.buffer.clone(),
        offset: self.offset,
    })
}
/// Returns a view restricted to `start..end` along `axis`.
///
/// Strides are kept; the byte offset advances by
/// `start * strides[axis] * dtype.element_size()`.
///
/// # Errors
/// Returns an internal error when `axis` is not a valid dimension, when
/// `start > end`, or when `end` exceeds the extent of `axis`.
pub fn slice_axis(&self, axis: usize, start: usize, end: usize) -> Result<Tensor> {
    let mut dims = self.shape.dims().to_vec();

    let extent = match dims.get(axis) {
        Some(&extent) => extent,
        None => return Err(SapientError::internal("slice axis out of bounds")),
    };
    let range_ok = start <= end && end <= extent;
    if !range_ok {
        return Err(SapientError::internal("slice range out of bounds"));
    }

    dims[axis] = end - start;
    let byte_shift = start * self.strides[axis] * self.dtype.element_size();
    Ok(Tensor {
        shape: Shape(dims),
        dtype: self.dtype,
        strides: self.strides.clone(),
        buffer: self.buffer.clone(),
        offset: self.offset + byte_shift,
    })
}
/// Returns `dtype.byte_count(numel)`: the number of bytes the view's
/// elements occupy.
pub fn byte_size(&self) -> usize {
    let elements = self.numel();
    self.dtype.byte_count(elements)
}
}
/// One-line summary of a tensor: shape, dtype and the buffer's device.
impl std::fmt::Display for Tensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let device = self.buffer.0.device();
        write!(
            f,
            "Tensor(shape={}, dtype={}, device={})",
            self.shape, self.dtype, device
        )
    }
}
/// Serde-facing mirror of [`Tensor`]: shape, dtype and a flat `f32` payload.
#[derive(Serialize, Deserialize)]
struct TensorProxy {
    /// Shape of the tensor.
    shape: Shape,
    /// Dtype of the tensor.
    dtype: DType,
    /// Flat `f32` values; left empty when the tensor is not F32.
    data: Vec<f32>,
}
/// Serializes a tensor as shape, dtype and flat `f32` data.
///
/// F32 tensors emit exactly `numel` values in logical (row-major) order, so
/// strided or offset views round-trip through `Deserialize`. The raw buffer
/// tail is not used because it can be longer than the view and ignores
/// strides. Non-F32 tensors emit an empty `data` array; only their shape and
/// dtype are preserved.
impl Serialize for Tensor {
    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
        let data: Vec<f32> = if self.dtype == DType::F32 {
            self.to_contiguous_f32_vec()
        } else {
            Vec::new()
        };
        let proxy = TensorProxy {
            shape: self.shape.clone(),
            dtype: self.dtype,
            data,
        };
        proxy.serialize(serializer)
    }
}
impl<'de> Deserialize<'de> for Tensor {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
let proxy = TensorProxy::deserialize(deserializer)?;
if proxy.data.is_empty() {
Tensor::zeros(proxy.shape, proxy.dtype).map_err(serde::de::Error::custom)
} else {
Tensor::from_f32(&proxy.data, proxy.shape).map_err(serde::de::Error::custom)
}
}
}
/// Shape and dtype of a tensor, without its data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorMeta {
    /// Shape of the described tensor.
    pub shape: Shape,
    /// Dtype of the described tensor.
    pub dtype: DType,
}
impl From<&Tensor> for TensorMeta {
fn from(t: &Tensor) -> Self {
Self {
shape: t.shape.clone(),
dtype: t.dtype,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Six sequential values used by the data-carrying tests.
    fn six_values() -> Vec<f32> {
        vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
    }

    /// `zeros` reports the requested shape, dtype and element count.
    #[test]
    fn zeros_dtype_shape() {
        let tensor = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert_eq!(tensor.shape().dims(), &[2, 3]);
        assert_eq!(tensor.dtype(), DType::F32);
        assert_eq!(tensor.numel(), 6);
    }

    /// Values passed to `from_f32` are readable through `as_f32_slice`.
    #[test]
    fn from_f32_roundtrip() {
        let values = six_values();
        let tensor = Tensor::from_f32(&values, vec![2, 3]).unwrap();
        assert_eq!(tensor.as_f32_slice(), values.as_slice());
    }

    /// Reshaping changes the dimensions but not the stored values.
    #[test]
    fn reshape_preserves_data() {
        let values = six_values();
        let original = Tensor::from_f32(&values, vec![2, 3]).unwrap();
        let reshaped = original.reshape(vec![3, 2]).unwrap();
        assert_eq!(reshaped.shape().dims(), &[3, 2]);
        assert_eq!(reshaped.as_f32_slice(), values.as_slice());
    }

    /// Reshaping to a different element count is rejected.
    #[test]
    fn reshape_wrong_numel() {
        let tensor = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        let outcome = tensor.reshape(vec![5]);
        assert!(outcome.is_err());
    }

    /// `t` swaps the two dimensions of a 2-D tensor.
    #[test]
    fn transpose_2d() {
        let tensor = Tensor::zeros(vec![3, 4], DType::F32).unwrap();
        let transposed = tensor.t().unwrap();
        assert_eq!(transposed.shape().dims(), &[4, 3]);
    }

    /// A 4x4 F32 tensor occupies 64 bytes.
    #[test]
    fn byte_size() {
        let tensor = Tensor::zeros(vec![4, 4], DType::F32).unwrap();
        assert_eq!(tensor.byte_size(), 64);
    }
}