use alloc::sync::Arc;
use alloc::vec::Vec;
use core::fmt;
use burn_backend::{DType, Element, TensorData, TensorMetadata};
use burn_std::{Bytes, Shape, bf16, f16};
use crate::layout::Layout;
/// A CPU tensor: reference-counted byte storage plus a strided [`Layout`].
///
/// `Clone` is cheap (an `Arc` refcount bump); the clones share storage until
/// one of them is mutated through the copy-on-write accessors below.
#[derive(Clone)]
pub struct FlexTensor {
// Raw backing bytes, shared between views; `Arc` gives O(1) clone and
// enables copy-on-write mutation.
data: Arc<Bytes>,
// Shape/stride/offset description mapping logical indices into `data`.
layout: Layout,
// Element type of the stored data.
dtype: DType,
}
impl fmt::Debug for FlexTensor {
    /// Compact debug form: shape, dtype, contiguity and storage-uniqueness.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("FlexTensor");
        out.field("shape", self.layout.shape());
        out.field("dtype", &self.dtype);
        out.field("contiguous", &self.layout.is_contiguous());
        out.field("unique", &self.is_unique());
        out.finish()
    }
}
impl FlexTensor {
    /// Wraps raw `data` with an explicit `layout` and element `dtype`.
    ///
    /// No validation is performed; the caller must ensure layout and dtype
    /// are consistent with the byte buffer.
    pub fn new(data: Bytes, layout: Layout, dtype: DType) -> Self {
        Self {
            data: Arc::new(data),
            layout,
            dtype,
        }
    }

    /// Builds a contiguous tensor that takes ownership of `data`'s bytes.
    pub fn from_data(data: TensorData) -> Self {
        let shape = data.shape.clone();
        let layout = Layout::contiguous(shape);
        let dtype = data.dtype;
        Self {
            data: Arc::new(data.bytes),
            layout,
            dtype,
        }
    }

    /// Converts into [`TensorData`], moving the bytes without a copy when
    /// the tensor is contiguous, starts at offset 0, and uniquely owns its
    /// storage.
    pub fn into_data(self) -> TensorData {
        if self.layout.is_contiguous() && self.layout.start_offset() == 0 {
            match Arc::try_unwrap(self.data) {
                // Sole owner: move the bytes out directly.
                Ok(bytes) => TensorData {
                    bytes,
                    shape: self.layout.shape().clone(),
                    dtype: self.dtype,
                },
                // Storage is shared with another tensor: copy the bytes.
                Err(arc) => {
                    let bytes = Bytes::from_bytes_vec((*arc).to_vec());
                    TensorData {
                        bytes,
                        shape: self.layout.shape().clone(),
                        dtype: self.dtype,
                    }
                }
            }
        } else {
            // Pack into a fresh contiguous buffer, then take the fast path.
            self.to_contiguous().into_data()
        }
    }

    /// `true` when no other tensor shares this storage (strong count == 1).
    #[inline]
    pub fn is_unique(&self) -> bool {
        Arc::strong_count(&self.data) == 1
    }

    /// The tensor's shape/stride/offset description.
    pub fn layout(&self) -> &Layout {
        &self.layout
    }

    /// Returns a view over the same storage with a different layout.
    pub fn with_layout(self, layout: Layout) -> Self {
        Self {
            data: self.data,
            layout,
            dtype: self.dtype,
        }
    }

    /// Element type of the stored data.
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Whether this view's elements are laid out densely (per the layout).
    pub fn is_contiguous(&self) -> bool {
        self.layout.is_contiguous()
    }

    /// The raw backing bytes — the entire storage buffer, not restricted to
    /// this view's elements.
    pub fn bytes(&self) -> &[u8] {
        &self.data
    }

    /// A new handle to the shared storage (cheap refcount bump, no copy).
    pub fn data_arc(&self) -> Arc<Bytes> {
        Arc::clone(&self.data)
    }

    /// Builds a tensor from already-shared storage.
    pub fn from_arc(data: Arc<Bytes>, layout: Layout, dtype: DType) -> Self {
        Self {
            data,
            layout,
            dtype,
        }
    }

    /// Typed read-only view of the entire storage buffer.
    ///
    /// Bool tensors may be read as `u8`, matching the debug assertion below.
    pub fn storage<E: Element + bytemuck::Pod>(&self) -> &[E] {
        debug_assert!(
            E::dtype() == self.dtype
                || (matches!(self.dtype, DType::Bool(_)) && E::dtype() == DType::U8),
            "storage: dtype mismatch (expected {:?}, got {:?})",
            self.dtype,
            E::dtype()
        );
        bytemuck::cast_slice(&self.data)
    }

    /// Typed mutable view of the storage, cloning the bytes first when they
    /// are shared (copy-on-write via `Arc::make_mut`).
    pub fn storage_mut<E: Element + bytemuck::Pod>(&mut self) -> &mut [E] {
        debug_assert!(
            E::dtype() == self.dtype
                || (matches!(self.dtype, DType::Bool(_)) && E::dtype() == DType::U8),
            "storage_mut: dtype mismatch (expected {:?}, got {:?})",
            self.dtype,
            E::dtype()
        );
        // Copy-on-write: clones only when another handle holds the storage.
        let bytes = Arc::make_mut(&mut self.data);
        bytemuck::cast_slice_mut(bytes)
    }

    /// Typed mutable view only when the storage is uniquely owned; never
    /// copies. Returns `None` when the storage is shared.
    pub fn try_storage_mut<E: Element + bytemuck::Pod>(&mut self) -> Option<&mut [E]> {
        debug_assert!(
            E::dtype() == self.dtype
                || (matches!(self.dtype, DType::Bool(_)) && E::dtype() == DType::U8),
            "try_storage_mut: dtype mismatch (expected {:?}, got {:?})",
            self.dtype,
            E::dtype()
        );
        // `Arc::get_mut` already returns `None` whenever the storage is
        // shared, so no separate `is_unique()` pre-check is needed.
        let bytes = Arc::get_mut(&mut self.data)?;
        Some(bytemuck::cast_slice_mut(bytes))
    }

    /// This view's elements as a typed slice, when the dtype matches exactly
    /// and the layout maps the view onto one contiguous storage range.
    pub fn as_slice<E: Element + bytemuck::Pod>(&self) -> Option<&[E]> {
        if E::dtype() != self.dtype {
            return None;
        }
        let storage: &[E] = self.storage();
        self.layout
            .contiguous_offsets()
            .map(|(start, end)| &storage[start..end])
    }

    /// Allocates a contiguous, zero-filled tensor of the given shape/dtype.
    pub fn empty(shape: Shape, dtype: DType) -> Self {
        let num_elements = shape.num_elements();
        let elem_size = dtype_size(dtype);
        let bytes = Bytes::from_bytes_vec(alloc::vec![0u8; num_elements * elem_size]);
        let layout = Layout::contiguous(shape);
        Self {
            data: Arc::new(bytes),
            layout,
            dtype,
        }
    }

    /// Zero-filled tensor; delegates to [`Self::empty`], whose buffer is
    /// already zero-initialized.
    pub fn zeros(shape: Shape, dtype: DType) -> Self {
        Self::empty(shape, dtype)
    }

    /// Returns a tensor whose elements are packed densely starting at
    /// offset 0, copying only when this view is not already in that form.
    pub fn to_contiguous(&self) -> Self {
        if self.is_contiguous() && self.layout.start_offset() == 0 {
            return self.clone();
        }
        match self.dtype {
            DType::F64 => self.copy_contiguous::<f64>(),
            DType::F32 => self.copy_contiguous::<f32>(),
            DType::F16 => self.copy_contiguous::<f16>(),
            DType::BF16 => self.copy_contiguous::<bf16>(),
            DType::I64 => self.copy_contiguous::<i64>(),
            DType::I32 => self.copy_contiguous::<i32>(),
            DType::I16 => self.copy_contiguous::<i16>(),
            DType::I8 => self.copy_contiguous::<i8>(),
            DType::U64 => self.copy_contiguous::<u64>(),
            DType::U32 => self.copy_contiguous::<u32>(),
            DType::U16 => self.copy_contiguous::<u16>(),
            DType::U8 => self.copy_contiguous::<u8>(),
            // Bools are stored one byte per element (see `storage`).
            DType::Bool(_) => self.copy_contiguous::<u8>(),
            _ => panic!("Unsupported dtype for contiguous copy: {:?}", self.dtype),
        }
    }

    /// Gathers this view's elements into a fresh contiguous buffer.
    fn copy_contiguous<E: Element + bytemuck::Pod>(&self) -> Self {
        let src: &[E] = bytemuck::cast_slice(&self.data);
        let n = self.layout.num_elements();
        let mut dst = Vec::with_capacity(n);
        if let Some((rows, cols, row_stride, col_stride)) = self.layout.as_2d_strides() {
            // Fast path: the layout collapses to 2-D strides, so elements can
            // be gathered with two nested loops and no per-element iterator.
            let offset = self.layout.start_offset() as isize;
            for row in 0..rows {
                let row_start = offset + row as isize * row_stride;
                for col in 0..cols {
                    let idx = (row_start + col as isize * col_stride) as usize;
                    dst.push(src[idx]);
                }
            }
        } else {
            // General path: walk every storage index in logical order.
            for idx in crate::strided_index::StridedIter::new(&self.layout) {
                dst.push(src[idx]);
            }
        }
        let bytes = Bytes::from_elems(dst);
        let layout = Layout::contiguous(self.layout.shape().clone());
        Self {
            data: Arc::new(bytes),
            layout,
            dtype: self.dtype,
        }
    }

    /// Reshapes without copying when the current strides admit the new
    /// shape; otherwise packs the data first. Element count must match.
    pub fn reshape(&self, new_shape: Shape) -> Self {
        debug_assert_eq!(
            self.layout.num_elements(),
            new_shape.num_elements(),
            "reshape must preserve total elements"
        );
        if let Some(new_layout) = self.layout.reshape(new_shape.clone()) {
            // Zero-copy: only the layout changes.
            Self {
                data: Arc::clone(&self.data),
                layout: new_layout,
                dtype: self.dtype,
            }
        } else {
            self.to_contiguous().reshape(new_shape)
        }
    }

    /// Swaps two dimensions; zero-copy (layout change only).
    pub fn transpose(&self, dim1: usize, dim2: usize) -> Self {
        Self {
            data: Arc::clone(&self.data),
            layout: self.layout.transpose(dim1, dim2),
            dtype: self.dtype,
        }
    }

    /// Restricts `dim` to `[start, start + len)`; zero-copy.
    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Self {
        Self {
            data: Arc::clone(&self.data),
            layout: self.layout.narrow(dim, start, len),
            dtype: self.dtype,
        }
    }

    /// Reorders dimensions according to `axes`; zero-copy.
    pub fn permute(&self, axes: &[usize]) -> Self {
        Self {
            data: Arc::clone(&self.data),
            layout: self.layout.permute(axes),
            dtype: self.dtype,
        }
    }
}
impl TensorMetadata for FlexTensor {
    /// An owned copy of this tensor's shape.
    fn shape(&self) -> Shape {
        self.layout.shape().clone()
    }

    /// Number of dimensions in the layout.
    fn rank(&self) -> usize {
        self.layout.num_dims()
    }

    /// Element type of the stored data.
    fn dtype(&self) -> DType {
        self.dtype
    }
}
fn dtype_size(dtype: DType) -> usize {
match dtype {
DType::F64 | DType::I64 | DType::U64 => 8,
DType::F32 | DType::I32 | DType::U32 => 4,
DType::F16 | DType::BF16 | DType::I16 | DType::U16 => 2,
DType::I8 | DType::U8 | DType::Bool(_) => 1,
_ => panic!("Unsupported dtype: {:?}", dtype),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn test_from_data_roundtrip() {
        // TensorData -> FlexTensor -> TensorData preserves metadata.
        let original = TensorData::from([1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let roundtripped = FlexTensor::from_data(original.clone()).into_data();
        assert_eq!(original.shape, roundtripped.shape);
        assert_eq!(original.dtype, roundtripped.dtype);
    }

    #[test]
    fn test_reshape() {
        let source = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let reshaped = FlexTensor::from_data(source).reshape(Shape::from(vec![3, 2]));
        assert_eq!(reshaped.shape().to_vec(), vec![3, 2]);
    }

    #[test]
    fn test_transpose() {
        let source = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let swapped = FlexTensor::from_data(source).transpose(0, 1);
        assert_eq!(swapped.shape().to_vec(), vec![3, 2]);
        // Transposing only rewrites the layout, so the view becomes strided.
        assert!(!swapped.is_contiguous());
    }

    #[test]
    fn test_clone_is_cheap() {
        let original = FlexTensor::from_data(TensorData::from([1.0f32, 2.0, 3.0, 4.0]));
        assert!(original.is_unique());
        let copy = original.clone();
        // Both handles now share a single storage allocation.
        assert!(!original.is_unique());
        assert!(!copy.is_unique());
        assert!(core::ptr::eq(
            original.bytes().as_ptr(),
            copy.bytes().as_ptr()
        ));
    }

    #[test]
    fn test_cow_on_mutation() {
        let original = FlexTensor::from_data(TensorData::from([1.0f32, 2.0, 3.0, 4.0]));
        let mut copy = original.clone();
        assert!(!original.is_unique());
        assert!(!copy.is_unique());
        // The first mutable access triggers copy-on-write, detaching `copy`.
        let elems: &mut [f32] = copy.storage_mut();
        elems[0] = 99.0;
        assert!(original.is_unique());
        assert!(copy.is_unique());
        assert_ne!(original.bytes().as_ptr(), copy.bytes().as_ptr());
        assert_eq!(original.storage::<f32>()[0], 1.0);
        assert_eq!(copy.storage::<f32>()[0], 99.0);
    }
}