use serde::{Deserialize, Serialize};
use serde::{Deserializer, Serializer};
use std::sync::Arc;
use crate::buffer::{BufferHandle, CpuBuffer};
use crate::dtype::DType;
use crate::error::{Result, SapientError};
use crate::shape::Shape;
/// An n-dimensional array: a shape, an element type and a window into a
/// backing buffer.
///
/// Views created by `reshape`, `t` and `slice_axis` clone the buffer handle
/// and only change the metadata stored here.
#[derive(Clone, Debug)]
pub struct Tensor {
    /// Logical extents of the tensor.
    shape: Shape,
    /// Element type of the data stored in `buffer`.
    dtype: DType,
    /// Per-axis strides, counted in elements (`slice_axis` multiplies them by
    /// the element size to obtain bytes).
    strides: Vec<usize>,
    /// Handle to the storage holding the tensor's bytes.
    buffer: BufferHandle,
    /// Byte offset of the first element inside `buffer`.
    offset: usize,
}
impl Tensor {
/// Creates a tensor of the given shape and dtype backed by a fresh buffer
/// obtained from `CpuBuffer::zeros`.
///
/// # Errors
/// Propagates any error from `Shape::validate` or `CpuBuffer::zeros`.
pub fn zeros(shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
    let target: Shape = shape.into();
    target.validate()?;

    let element_count = target.numel();
    let layout = target.strides();
    let storage = CpuBuffer::zeros(element_count, dtype)?;

    Ok(Self {
        shape: target,
        dtype,
        strides: layout,
        buffer: BufferHandle::new(storage),
        offset: 0,
    })
}
/// Builds a `DType::F32` tensor from `data`, which must contain exactly
/// `shape.numel()` values.
///
/// # Errors
/// - Any error reported by `Shape::validate`.
/// - `SapientError::ShapeMismatch` when `data.len()` differs from the
///   shape's element count (`got` carries `data.len()`).
/// - Any error from `CpuBuffer::from_f32_slice`.
pub fn from_f32(data: &[f32], shape: impl Into<Shape>) -> Result<Self> {
    let target: Shape = shape.into();
    target.validate()?;

    let supplied = data.len();
    if supplied != target.numel() {
        let expected = target.dims().to_vec();
        return Err(SapientError::ShapeMismatch {
            expected,
            got: vec![supplied],
        });
    }

    let layout = target.strides();
    let storage = CpuBuffer::from_f32_slice(data)?;
    Ok(Self {
        shape: target,
        dtype: DType::F32,
        strides: layout,
        buffer: BufferHandle::new(storage),
        offset: 0,
    })
}
/// Builds a `DType::BF16` tensor from raw bytes, two bytes per element.
///
/// # Errors
/// - Any error reported by `Shape::validate`.
/// - `SapientError::ShapeMismatch` when `data.len()` is not
///   `2 * shape.numel()`; `got` reports `data.len() / 2` (integer division).
/// - Any error from `CpuBuffer::from_bytes_slice`.
pub fn from_bf16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
    let target: Shape = shape.into();
    target.validate()?;

    let wanted_bytes = target.numel() * 2;
    let supplied_bytes = data.len();
    if supplied_bytes != wanted_bytes {
        return Err(SapientError::ShapeMismatch {
            expected: target.dims().to_vec(),
            got: vec![supplied_bytes / 2],
        });
    }

    let layout = target.strides();
    let storage = CpuBuffer::from_bytes_slice(data)?;
    Ok(Self {
        shape: target,
        dtype: DType::BF16,
        strides: layout,
        buffer: BufferHandle::new(storage),
        offset: 0,
    })
}
/// Builds a `DType::F16` tensor from raw bytes, two bytes per element.
///
/// # Errors
/// - Any error reported by `Shape::validate`.
/// - `SapientError::ShapeMismatch` when `data.len()` is not
///   `2 * shape.numel()`; `got` reports `data.len() / 2` (integer division).
/// - Any error from `CpuBuffer::from_bytes_slice`.
pub fn from_f16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
    let dims: Shape = shape.into();
    dims.validate()?;

    let byte_len = data.len();
    let length_ok = byte_len == dims.numel() * 2;
    if !length_ok {
        let expected = dims.dims().to_vec();
        let got = vec![byte_len / 2];
        return Err(SapientError::ShapeMismatch { expected, got });
    }

    let strides = dims.strides();
    let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
    Ok(Self {
        shape: dims,
        dtype: DType::F16,
        strides,
        buffer,
        offset: 0,
    })
}
/// Builds a quantized tensor from raw block bytes.
///
/// `data` must be exactly `dtype.byte_count(shape.numel())` bytes long.
///
/// # Errors
/// - `SapientError::TypeMismatch` when `dtype.is_quantized()` is false
///   (checked before the shape is examined).
/// - Any error reported by `Shape::validate`.
/// - `SapientError::ShapeMismatch` carrying the expected and actual byte
///   counts when the lengths disagree.
/// - Any error from `CpuBuffer::from_bytes_slice`.
pub fn from_quant_bytes(data: &[u8], shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
    if !dtype.is_quantized() {
        let expected = "a quantized dtype (Q4_0 or Q8_0)".into();
        let got = dtype.to_string();
        return Err(SapientError::TypeMismatch { expected, got });
    }

    let target: Shape = shape.into();
    target.validate()?;

    let wanted_bytes = dtype.byte_count(target.numel());
    let supplied_bytes = data.len();
    if supplied_bytes != wanted_bytes {
        return Err(SapientError::ShapeMismatch {
            expected: vec![wanted_bytes],
            got: vec![supplied_bytes],
        });
    }

    let layout = target.strides();
    let storage = CpuBuffer::from_bytes_slice(data)?;
    Ok(Self {
        shape: target,
        dtype,
        strides: layout,
        buffer: BufferHandle::new(storage),
        offset: 0,
    })
}
/// Wraps a single `f32` in a tensor whose shape is `Shape::scalar()`.
///
/// # Errors
/// Forwards any error from `Tensor::from_f32`.
pub fn scalar_f32(v: f32) -> Result<Self> {
    let single = [v];
    Self::from_f32(&single, Shape::scalar())
}
/// Wraps an existing buffer as a tensor whose data starts `offset` bytes
/// into `buffer`.
///
/// The tensor gets the default strides of `shape` (`Shape::strides`).
///
/// # Errors
/// - Any error reported by `Shape::validate`.
/// - `SapientError::BufferSizeMismatch` when the buffer is shorter than
///   `offset + dtype.byte_count(shape.numel())` bytes.
pub fn from_buffer(
    shape: impl Into<Shape>,
    dtype: DType,
    buffer: BufferHandle,
    offset: usize,
) -> Result<Self> {
    let shape = shape.into();
    shape.validate()?;

    let required = dtype.byte_count(shape.numel());
    // Saturate instead of wrapping: in release builds a plain `+` would wrap
    // for a huge `offset`, producing a small end position that passes the
    // length check and yields a tensor pointing outside its buffer.
    let end = offset.saturating_add(required);
    let available = buffer.len();
    if available < end {
        return Err(SapientError::BufferSizeMismatch {
            expected: end,
            got: available,
        });
    }

    let strides = shape.strides();
    Ok(Self {
        shape,
        dtype,
        strides,
        buffer,
        offset,
    })
}
pub fn shape(&self) -> &Shape {
&self.shape
}
pub fn dtype(&self) -> DType {
self.dtype
}
/// Returns the number of dimensions, as reported by `Shape::ndim`.
pub fn ndim(&self) -> usize {
    let shape = &self.shape;
    shape.ndim()
}
/// Returns the number of elements, as reported by `Shape::numel`.
pub fn numel(&self) -> usize {
    let shape = &self.shape;
    shape.numel()
}
/// Returns the per-axis strides of this tensor.
pub fn strides(&self) -> &[usize] {
    self.strides.as_slice()
}
pub fn buffer(&self) -> &BufferHandle {
&self.buffer
}
pub fn offset(&self) -> usize {
self.offset
}
/// Returns `true` when the shape reports itself as scalar or the tensor
/// holds exactly one element.
pub fn is_scalar(&self) -> bool {
    if self.shape.is_scalar() {
        true
    } else {
        self.numel() == 1
    }
}
/// Returns `true` when the strides equal the shape's default strides
/// *and* the data starts at byte 0 of the buffer.
///
/// A view with a non-zero offset is reported as non-contiguous even if its
/// strides are the default ones.
pub fn is_contiguous(&self) -> bool {
    let default_layout = self.shape.strides();
    let layout_matches = self.strides == default_layout;
    layout_matches && self.offset == 0
}
/// Returns the tensor's raw bytes, starting at `offset`.
///
/// For quantized dtypes the slice is exactly
/// `dtype.byte_count(numel())` bytes long. For every other dtype it runs
/// to the end of the backing buffer, so it can be longer than the tensor's
/// own data when the tensor is a view into a larger buffer.
///
/// # Panics
/// Panics if the requested range lies outside the buffer.
pub fn as_bytes(&self) -> &[u8] {
    let storage = self.buffer.as_bytes();
    let start = self.offset;
    if !self.dtype.is_quantized() {
        return &storage[start..];
    }
    let stop = start + self.dtype.byte_count(self.numel());
    &storage[start..stop]
}
/// Returns the raw block bytes of a quantized tensor (see `as_bytes`).
///
/// # Panics
/// Panics when the tensor's dtype is not quantized.
pub fn as_quant_blocks(&self) -> &[u8] {
    let dtype = self.dtype;
    if !dtype.is_quantized() {
        panic!(
            "as_quant_blocks() called on non-quantized tensor (dtype = {})",
            dtype
        );
    }
    self.as_bytes()
}
/// Reinterprets the tensor's bytes (see `as_bytes`) as a slice of `f32`
/// without copying.
///
/// The slice covers everything `as_bytes` exposes, so for a view into a
/// larger buffer it may hold more than `numel()` values.
///
/// # Panics
/// - when the dtype is not `DType::F32`;
/// - when the byte length is not a multiple of 4;
/// - when the data does not start on an `f32`-aligned address.
pub fn as_f32_slice(&self) -> &[f32] {
    assert_eq!(
        self.dtype,
        DType::F32,
        "Tensor dtype is not F32 — call to_f32_vec() instead"
    );
    let bytes = self.as_bytes();
    assert_eq!(bytes.len() % 4, 0);

    // An empty byte slice may carry a dangling, byte-aligned pointer that is
    // not valid for `from_raw_parts::<f32>`; hand out a static empty slice.
    if bytes.is_empty() {
        return &[];
    }

    // A misaligned `&[f32]` is undefined behaviour, so refuse it up front
    // rather than relying on the allocator or the byte offset being friendly.
    assert_eq!(
        bytes.as_ptr() as usize % std::mem::align_of::<f32>(),
        0,
        "F32 tensor data is not 4-byte aligned (byte offset = {})",
        self.offset
    );

    // SAFETY: the pointer is non-null and (checked above) aligned for f32;
    // it addresses `bytes.len()` initialised bytes, which is a whole number
    // of f32 values; every bit pattern is a valid f32; the returned slice
    // borrows from `self`, so the buffer outlives it.
    unsafe {
        std::slice::from_raw_parts(
            bytes.as_ptr().cast::<f32>(),
            bytes.len() / std::mem::size_of::<f32>(),
        )
    }
}
/// Returns the tensor's values as `f32`, borrowing when the tensor is
/// already `DType::F32` and converting into an owned vector otherwise.
pub fn to_f32_cow(&self) -> std::borrow::Cow<'_, [f32]> {
    use std::borrow::Cow;

    match self.dtype {
        DType::F32 => Cow::Borrowed(self.as_f32_slice()),
        _ => Cow::Owned(self.to_f32_vec()),
    }
}
/// Converts the tensor's stored values to an owned `Vec<f32>`.
///
/// - `F32`: copies `as_f32_slice()`.
/// - `BF16` / `F16`: decodes every little-endian 2-byte value in
///   `as_bytes()`.
/// - `Q4_0`: each block is a 2-byte little-endian `f16` scale followed by
///   `QUANT_BLOCK_SIZE / 2` bytes; the low nibble of byte `j` is element
///   `j`, the high nibble is element `j + QUANT_BLOCK_SIZE / 2`, and each
///   nibble is offset by -8 and multiplied by the scale.
/// - `Q8_0`: each block is a 2-byte little-endian `f16` scale followed by
///   `QUANT_BLOCK_SIZE` signed bytes, each multiplied by the scale.
///
/// Strides are not consulted; values are produced in storage order.
///
/// # Panics
/// - for any dtype not listed above;
/// - whenever `as_f32_slice` / `as_bytes` panic;
/// - for quantized dtypes, if a decoded block would write past `numel()`
///   elements (NOTE(review): this depends on how `DType::byte_count` rounds
///   partial blocks — confirm).
pub fn to_f32_vec(&self) -> Vec<f32> {
    use crate::dtype::{Q4_0_BLOCK_BYTES, Q8_0_BLOCK_BYTES, QUANT_BLOCK_SIZE};
    match self.dtype {
        DType::F32 => self.as_f32_slice().to_vec(),
        DType::BF16 => self
            .as_bytes()
            .chunks_exact(2)
            .map(|pair| f32::from(half::bf16::from_le_bytes([pair[0], pair[1]])))
            .collect(),
        DType::F16 => self
            .as_bytes()
            .chunks_exact(2)
            .map(|pair| half::f16::from_le_bytes([pair[0], pair[1]]).to_f32())
            .collect(),
        DType::Q4_0 => {
            let half_block = QUANT_BLOCK_SIZE / 2;
            let bytes = self.as_bytes();
            let mut out = vec![0.0f32; self.numel()];
            for (b, block) in bytes.chunks_exact(Q4_0_BLOCK_BYTES).enumerate() {
                let scale = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
                let base = b * QUANT_BLOCK_SIZE;
                for j in 0..half_block {
                    let packed = block[2 + j];
                    let lo = (packed & 0x0f) as i32 - 8;
                    let hi = (packed >> 4) as i32 - 8;
                    out[base + j] = lo as f32 * scale;
                    out[base + j + half_block] = hi as f32 * scale;
                }
            }
            out
        }
        DType::Q8_0 => {
            let bytes = self.as_bytes();
            let mut out = vec![0.0f32; self.numel()];
            for (b, block) in bytes.chunks_exact(Q8_0_BLOCK_BYTES).enumerate() {
                let scale = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
                let base = b * QUANT_BLOCK_SIZE;
                for j in 0..QUANT_BLOCK_SIZE {
                    out[base + j] = block[2 + j] as i8 as f32 * scale;
                }
            }
            out
        }
        // Previously this arm forwarded to `as_f32_slice()`, which can only
        // panic here — with a message telling the caller to use
        // `to_f32_vec()`. Fail with a message that names the real problem.
        other => panic!("to_f32_vec() does not support dtype {}", other),
    }
}
/// Returns an `F32` version of this tensor.
///
/// An `F32` tensor is returned as a clone of `self`; any other dtype is
/// converted with `to_f32_vec` and rebuilt through `Tensor::from_f32` with
/// a clone of the current shape.
///
/// # Errors
/// Forwards any error from `Tensor::from_f32`.
pub fn to_f32_tensor(&self) -> Result<Tensor> {
    if self.dtype == DType::F32 {
        return Ok(self.clone());
    }
    let values = self.to_f32_vec();
    Tensor::from_f32(&values, self.shape.clone())
}
/// Returns the tensor's data as a mutable `f32` slice, starting at the
/// tensor's byte offset and running to the end of the buffer.
///
/// # Errors
/// Returns an internal error when:
/// - the dtype is not `DType::F32`;
/// - the buffer is shared (`Arc::get_mut` fails);
/// - the tensor's offset lies beyond the end of the buffer;
/// - the remaining byte length is not a multiple of 4;
/// - the data does not start on an `f32`-aligned address.
pub fn as_f32_slice_mut(&mut self) -> Result<&mut [f32]> {
    if self.dtype != DType::F32 {
        return Err(SapientError::internal("Tensor dtype is not F32"));
    }
    let offset = self.offset;
    let buf = Arc::get_mut(&mut self.buffer.0)
        .ok_or_else(|| SapientError::internal("Cannot mutate shared tensor buffer"))?;
    let bytes = buf.as_bytes_mut();
    // This function already reports failures through `Result`, so an
    // out-of-range offset becomes an error rather than an indexing panic.
    let bytes = bytes
        .get_mut(offset..)
        .ok_or_else(|| SapientError::internal("Tensor offset exceeds buffer length"))?;
    if bytes.len() % 4 != 0 {
        return Err(SapientError::internal("Buffer length not a multiple of 4"));
    }
    // An empty byte slice may carry a dangling, byte-aligned pointer that is
    // not valid for `from_raw_parts_mut::<f32>`.
    if bytes.is_empty() {
        return Ok(&mut []);
    }
    // A misaligned `&mut [f32]` is undefined behaviour; reject it explicitly.
    if bytes.as_ptr() as usize % std::mem::align_of::<f32>() != 0 {
        return Err(SapientError::internal(
            "Tensor data is not aligned for f32 access",
        ));
    }
    // SAFETY: the pointer is non-null and (checked above) aligned for f32;
    // it addresses `bytes.len()` initialised bytes, a whole number of f32
    // values; every bit pattern is a valid f32; `Arc::get_mut` succeeded, so
    // this is the only reference to the buffer for the returned lifetime.
    Ok(unsafe {
        std::slice::from_raw_parts_mut(
            bytes.as_mut_ptr().cast::<f32>(),
            bytes.len() / std::mem::size_of::<f32>(),
        )
    })
}
pub fn reshape(&self, new_shape: impl Into<Shape>) -> Result<Tensor> {
let new_shape = self.shape.reshape(new_shape.into().dims().to_vec())?;
let strides = new_shape.strides();
Ok(Tensor {
shape: new_shape,
dtype: self.dtype,
strides,
buffer: self.buffer.clone(),
offset: self.offset,
})
}
/// Returns the transpose of a 2-D tensor as a view: the two dimensions and
/// the two strides are swapped, while buffer and offset are shared.
///
/// # Errors
/// Returns an internal error when the tensor is not 2-D.
pub fn t(&self) -> Result<Tensor> {
    if self.ndim() != 2 {
        return Err(SapientError::internal("t() requires a 2-D tensor"));
    }

    let mut swapped_dims = self.shape.dims().to_vec();
    swapped_dims.swap(0, 1);

    let mut swapped_strides = self.strides.clone();
    swapped_strides.swap(0, 1);

    Ok(Tensor {
        shape: Shape(swapped_dims),
        dtype: self.dtype,
        strides: swapped_strides,
        buffer: self.buffer.clone(),
        offset: self.offset,
    })
}
/// Returns a view restricted to indices `start..end` along `axis`.
///
/// The view shares the buffer and keeps the strides of `self`; only the
/// extent of `axis` and the byte offset change. The offset advances by
/// `start * strides[axis] * dtype.element_size()` bytes.
///
/// NOTE(review): the offset arithmetic relies on `DType::element_size`;
/// whether that is meaningful for block-quantized dtypes is not visible
/// here — confirm before slicing quantized tensors.
///
/// # Errors
/// Returns an internal error when `axis` is not a valid axis, or when
/// `start > end` or `end` exceeds the axis length.
pub fn slice_axis(&self, axis: usize, start: usize, end: usize) -> Result<Tensor> {
    let mut extents = self.shape.dims().to_vec();
    let axis_len = match extents.get(axis) {
        Some(&len) => len,
        None => return Err(SapientError::internal("slice axis out of bounds")),
    };
    let range_ok = start <= end && end <= axis_len;
    if !range_ok {
        return Err(SapientError::internal("slice range out of bounds"));
    }
    extents[axis] = end - start;

    let skipped_bytes = start * self.strides[axis] * self.dtype.element_size();
    Ok(Tensor {
        shape: Shape(extents),
        dtype: self.dtype,
        strides: self.strides.clone(),
        buffer: self.buffer.clone(),
        offset: self.offset + skipped_bytes,
    })
}
/// Returns the number of bytes `numel()` elements of this dtype occupy,
/// as computed by `DType::byte_count`.
pub fn byte_size(&self) -> usize {
    let element_count = self.numel();
    self.dtype.byte_count(element_count)
}
}
impl std::fmt::Display for Tensor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Tensor(shape={}, dtype={}, device={})",
self.shape,
self.dtype,
self.buffer.0.device()
)
}
}
/// Serialized form of a `Tensor`, used by its `Serialize` and `Deserialize`
/// impls. The field names and order below define the wire format.
#[derive(Serialize, Deserialize)]
struct TensorProxy {
// Shape of the tensor.
shape: Shape,
// Element type of the tensor.
dtype: DType,
// The `f32` values for F32 tensors; written as an empty vector for every
// other dtype.
data: Vec<f32>,
}
impl Serialize for Tensor {
fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
let data: Vec<f32> = if self.dtype == DType::F32 {
self.as_f32_slice().to_vec()
} else {
vec![] };
TensorProxy {
shape: self.shape.clone(),
dtype: self.dtype,
data,
}
.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for Tensor {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
let proxy = TensorProxy::deserialize(deserializer)?;
if proxy.data.is_empty() {
Tensor::zeros(proxy.shape, proxy.dtype).map_err(serde::de::Error::custom)
} else {
Tensor::from_f32(&proxy.data, proxy.shape).map_err(serde::de::Error::custom)
}
}
}
/// Shape and dtype of a tensor, without its data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorMeta {
/// Shape of the described tensor.
pub shape: Shape,
/// Element type of the described tensor.
pub dtype: DType,
}
impl From<&Tensor> for TensorMeta {
fn from(t: &Tensor) -> Self {
Self {
shape: t.shape.clone(),
dtype: t.dtype,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `zeros` reports the requested shape, dtype and element count.
    #[test]
    fn zeros_dtype_shape() {
        let zeroed = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert_eq!(zeroed.shape().dims(), &[2, 3]);
        assert_eq!(zeroed.dtype(), DType::F32);
        assert_eq!(zeroed.numel(), 6);
    }

    /// Values passed to `from_f32` come back unchanged from `as_f32_slice`.
    #[test]
    fn from_f32_roundtrip() {
        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_f32(&values, vec![2, 3]).unwrap();
        assert_eq!(tensor.as_f32_slice(), values.as_slice());
    }

    /// Reshaping changes the dims but leaves the stored values untouched.
    #[test]
    fn reshape_preserves_data() {
        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let original = Tensor::from_f32(&values, vec![2, 3]).unwrap();
        let reshaped = original.reshape(vec![3, 2]).unwrap();
        assert_eq!(reshaped.shape().dims(), &[3, 2]);
        assert_eq!(reshaped.as_f32_slice(), values.as_slice());
    }

    /// Reshaping to a different element count is rejected.
    #[test]
    fn reshape_wrong_numel() {
        let six = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        let outcome = six.reshape(vec![5]);
        assert!(outcome.is_err());
    }

    /// `t()` swaps the two dimensions of a 2-D tensor.
    #[test]
    fn transpose_2d() {
        let matrix = Tensor::zeros(vec![3, 4], DType::F32).unwrap();
        let transposed = matrix.t().unwrap();
        assert_eq!(transposed.shape().dims(), &[4, 3]);
    }

    /// A 4×4 F32 tensor occupies 16 * 4 bytes.
    #[test]
    fn byte_size() {
        let square = Tensor::zeros(vec![4, 4], DType::F32).unwrap();
        assert_eq!(square.byte_size(), 64);
    }
}