#[cfg(target_has_atomic = "ptr")]
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::fmt;
#[cfg(not(target_has_atomic = "ptr"))]
use portable_atomic_util::Arc;
use burn_backend::{DType, Element, TensorData, TensorMetadata};
use burn_std::{Bytes, Shape, bf16, f16};
use crate::layout::Layout;
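/// A CPU tensor backed by a reference-counted byte buffer and a strided [`Layout`].
///
/// Cloning is cheap: clones share the same `Arc<Bytes>` storage and only copy the
/// layout and dtype metadata. Writes through [`FlexTensor::storage_mut`] are
/// copy-on-write, so a shared buffer is duplicated before being modified.
///
/// Illustrative sketch of the copy-on-write behaviour (mirrors the unit tests in
/// this module; import paths omitted, so the snippet is not compiled):
///
/// ```ignore
/// let a = FlexTensor::from_data(TensorData::from([1.0f32, 2.0, 3.0, 4.0]));
/// let mut b = a.clone();           // shares storage with `a`
/// b.storage_mut::<f32>()[0] = 9.0; // copy-on-write: `a` keeps its original data
/// assert_eq!(a.storage::<f32>()[0], 1.0);
/// assert_eq!(b.storage::<f32>()[0], 9.0);
/// ```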
#[derive(Clone)]
pub struct FlexTensor {
data: Arc<Bytes>,
layout: Layout,
dtype: DType,
}
impl fmt::Debug for FlexTensor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FlexTensor")
.field("shape", self.layout.shape())
.field("dtype", &self.dtype)
.field("contiguous", &self.layout.is_contiguous())
.field("unique", &self.is_unique())
.finish()
}
}
impl FlexTensor {
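    /// Wraps raw bytes in a tensor with the given layout and dtype.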
pub fn new(data: Bytes, layout: Layout, dtype: DType) -> Self {
Self {
data: Arc::new(data),
layout,
dtype,
}
}
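    /// Takes ownership of a [`TensorData`] buffer and wraps it in a contiguous tensor.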
pub fn from_data(data: TensorData) -> Self {
let shape = data.shape.clone();
let layout = Layout::contiguous(shape);
let dtype = data.dtype;
Self {
data: Arc::new(data.bytes),
layout,
dtype,
}
}
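    /// Converts back into [`TensorData`].
    ///
    /// The buffer is reused without copying when it is uniquely owned, contiguous at
    /// offset zero, and exactly the expected size; otherwise the relevant bytes are
    /// copied (strided views go through [`FlexTensor::to_contiguous`] first).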
pub fn into_data(self) -> TensorData {
if self.layout.is_contiguous() && self.layout.start_offset() == 0 {
let expected_bytes = self.layout.num_elements() * dtype_size(self.dtype);
assert!(
expected_bytes <= self.data.len(),
"into_data: buffer ({} bytes) too small for {} elements of {:?}",
self.data.len(),
self.layout.num_elements(),
self.dtype
);
if self.data.len() == expected_bytes {
match Arc::try_unwrap(self.data) {
Ok(bytes) => TensorData {
bytes,
shape: self.layout.shape().clone(),
dtype: self.dtype,
},
Err(arc) => {
let bytes = Bytes::from_bytes_vec((*arc)[..expected_bytes].to_vec());
TensorData {
bytes,
shape: self.layout.shape().clone(),
dtype: self.dtype,
}
}
}
} else {
let bytes = Bytes::from_bytes_vec(self.data[..expected_bytes].to_vec());
TensorData {
bytes,
shape: self.layout.shape().clone(),
dtype: self.dtype,
}
}
} else {
self.to_contiguous().into_data()
}
}
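    /// `true` if this tensor is the sole owner of its storage (strong count of 1).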
#[inline]
pub fn is_unique(&self) -> bool {
Arc::strong_count(&self.data) == 1
}
pub fn layout(&self) -> &Layout {
&self.layout
}
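    /// Returns a tensor that shares this tensor's storage but uses `layout` instead.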
pub fn with_layout(self, layout: Layout) -> Self {
Self {
data: self.data,
layout,
dtype: self.dtype,
}
}
pub fn dtype(&self) -> DType {
self.dtype
}
pub fn is_contiguous(&self) -> bool {
self.layout.is_contiguous()
}
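    /// Raw view of the entire underlying byte buffer, ignoring the layout.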
pub fn bytes(&self) -> &[u8] {
&self.data
}
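    /// Clones the reference-counted handle to the underlying byte buffer.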
pub fn data_arc(&self) -> Arc<Bytes> {
Arc::clone(&self.data)
}
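    /// Builds a tensor from an already shared buffer without copying it.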
pub fn from_arc(data: Arc<Bytes>, layout: Layout, dtype: DType) -> Self {
Self {
data,
layout,
dtype,
}
}
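    /// Reinterprets the whole underlying buffer as a slice of `E`, ignoring the layout.
    ///
    /// Panics if `E` does not match the tensor dtype; bool tensors stored as `u8`
    /// may also be read as `u8`.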
pub fn storage<E: Element + bytemuck::Pod>(&self) -> &[E] {
assert!(
E::dtype() == self.dtype
|| (matches!(
self.dtype,
DType::Bool(burn_std::BoolStore::Native | burn_std::BoolStore::U8)
) && E::dtype() == DType::U8),
"storage: dtype mismatch (expected {:?}, got {:?})",
self.dtype,
E::dtype()
);
bytemuck::cast_slice(&self.data)
}
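    /// Mutable variant of [`FlexTensor::storage`]: if the buffer is shared, it is
    /// cloned first (copy-on-write) so the returned slice is uniquely owned.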
pub fn storage_mut<E: Element + bytemuck::Pod>(&mut self) -> &mut [E] {
assert!(
E::dtype() == self.dtype
|| (matches!(
self.dtype,
DType::Bool(burn_std::BoolStore::Native | burn_std::BoolStore::U8)
) && E::dtype() == DType::U8),
"storage_mut: dtype mismatch (expected {:?}, got {:?})",
self.dtype,
E::dtype()
);
let bytes = Arc::make_mut(&mut self.data);
bytemuck::cast_slice_mut(bytes)
}
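    /// Like [`FlexTensor::storage_mut`], but returns `None` instead of copying when
    /// the buffer is shared.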
pub fn try_storage_mut<E: Element + bytemuck::Pod>(&mut self) -> Option<&mut [E]> {
assert!(
E::dtype() == self.dtype
|| (matches!(
self.dtype,
DType::Bool(burn_std::BoolStore::Native | burn_std::BoolStore::U8)
) && E::dtype() == DType::U8),
"try_storage_mut: dtype mismatch (expected {:?}, got {:?})",
self.dtype,
E::dtype()
);
if self.is_unique() {
let bytes = Arc::get_mut(&mut self.data)?;
Some(bytemuck::cast_slice_mut(bytes))
} else {
None
}
}
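    /// Typed view of this tensor's elements, or `None` if `E` does not match the
    /// dtype exactly or the layout is not a single contiguous run.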
pub fn as_slice<E: Element + bytemuck::Pod>(&self) -> Option<&[E]> {
if E::dtype() != self.dtype {
return None;
}
let storage: &[E] = self.storage();
self.layout
.contiguous_offsets()
.map(|(start, end)| &storage[start..end])
}
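    /// Allocates a new contiguous tensor; the backing buffer is zero-initialized.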
pub fn empty(shape: Shape, dtype: DType) -> Self {
let num_elements = shape.num_elements();
let elem_size = dtype_size(dtype);
let bytes = Bytes::from_bytes_vec(alloc::vec![0u8; num_elements * elem_size]);
let layout = Layout::contiguous(shape);
Self {
data: Arc::new(bytes),
layout,
dtype,
}
}
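    /// Zero-filled tensor. Delegates to [`FlexTensor::empty`], whose buffer is
    /// already zeroed.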
pub fn zeros(shape: Shape, dtype: DType) -> Self {
Self::empty(shape, dtype)
}
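    /// Contiguous tensor with every element set to `value`; the size of `E` must
    /// match the element size of `dtype`.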
pub fn filled_typed<E: bytemuck::Pod + Send + Sync>(
shape: Shape,
dtype: DType,
value: E,
) -> Self {
assert_eq!(
dtype_size(dtype),
core::mem::size_of::<E>(),
"filled_typed: dtype size mismatch"
);
let n = shape.num_elements();
let data = alloc::vec![value; n];
let bytes = Bytes::from_elems(data);
Self {
data: Arc::new(bytes),
layout: Layout::contiguous(shape),
dtype,
}
}
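    /// Returns a compact, row-major copy of this tensor.
    ///
    /// Tensors that are already contiguous, start at offset zero, and exactly fill
    /// their buffer are returned as cheap clones. Illustrative use (mirrors the unit
    /// tests; not compiled):
    ///
    /// ```ignore
    /// let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
    /// let t = FlexTensor::from_data(data);
    /// let view = t.transpose(0, 1);     // zero-copy, non-contiguous view
    /// let owned = view.to_contiguous(); // materializes a row-major copy
    /// ```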
pub fn to_contiguous(&self) -> Self {
if self.is_contiguous()
&& self.layout.start_offset() == 0
&& self.data.len() == self.layout.num_elements() * dtype_size(self.dtype)
{
return self.clone();
}
match self.dtype {
DType::F64 => self.copy_contiguous::<f64>(),
DType::F32 => self.copy_contiguous::<f32>(),
DType::F16 => self.copy_contiguous::<f16>(),
DType::BF16 => self.copy_contiguous::<bf16>(),
DType::I64 => self.copy_contiguous::<i64>(),
DType::I32 => self.copy_contiguous::<i32>(),
DType::I16 => self.copy_contiguous::<i16>(),
DType::I8 => self.copy_contiguous::<i8>(),
DType::U64 => self.copy_contiguous::<u64>(),
DType::U32 => self.copy_contiguous::<u32>(),
DType::U16 => self.copy_contiguous::<u16>(),
DType::U8 => self.copy_contiguous::<u8>(),
DType::Bool(burn_std::BoolStore::Native | burn_std::BoolStore::U8) => {
self.copy_contiguous::<u8>()
}
DType::Bool(burn_std::BoolStore::U32) => {
panic!("burn-flex: Bool(U32) storage is not yet supported")
}
_ => panic!("Unsupported dtype for contiguous copy: {:?}", self.dtype),
}
}
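    /// Gathers the strided elements into a fresh contiguous buffer, using fast paths
    /// for collapsed 1-D and 2-D layouts with non-negative strides and a generic
    /// strided iterator otherwise.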
fn copy_contiguous<E: Element + bytemuck::Pod>(&self) -> Self {
let src: &[E] = bytemuck::cast_slice(&self.data);
let n = self.layout.num_elements();
let mut dst = Vec::with_capacity(n);
let collapsed = collapse_for_copy(self.layout.shape(), self.layout.strides());
let (shape, strides) = collapsed.as_slices();
let offset = self.layout.start_offset() as isize;
        let all_non_negative = strides.iter().all(|&s| s >= 0);
        if shape.len() <= 1 && all_non_negative {
let collapsed_numel = if shape.is_empty() { 1 } else { shape[0] };
debug_assert_eq!(n, collapsed_numel);
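            // SAFETY: `dst` was allocated with capacity `n`, `E: Pod`, and every
            // element is written below before the buffer is read.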
unsafe { dst.set_len(n) };
if shape.is_empty() {
if n > 0 {
dst[0] = src[offset as usize];
}
} else {
let len = shape[0];
let stride = strides[0];
if stride == 1 {
dst[..len].copy_from_slice(&src[offset as usize..offset as usize + len]);
} else {
for (i, slot) in dst.iter_mut().take(len).enumerate() {
let idx = (offset + i as isize * stride) as usize;
*slot = src[idx];
}
}
}
        } else if shape.len() == 2 && all_non_negative {
            debug_assert_eq!(shape[0] * shape[1], n, "collapsed 2D shape must cover all elements");
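            // SAFETY: `dst` was allocated with capacity `n`, `E: Pod`, and
            // `copy_2d_tiled` writes all `rows * cols == n` elements.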
unsafe { dst.set_len(n) };
copy_2d_tiled(
&mut dst, src, offset, shape[0], shape[1], strides[0], strides[1],
);
} else {
for idx in crate::strided_index::StridedIter::new(&self.layout) {
dst.push(src[idx]);
}
}
let bytes = Bytes::from_elems(dst);
let layout = Layout::contiguous(self.layout.shape().clone());
Self {
data: Arc::new(bytes),
layout,
dtype: self.dtype,
}
}
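    /// Reshapes without copying when the layout permits it, otherwise materializes a
    /// contiguous copy first.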
pub fn reshape(&self, new_shape: Shape) -> Self {
assert_eq!(
self.layout.num_elements(),
new_shape.num_elements(),
"reshape must preserve total elements"
);
if let Some(new_layout) = self.layout.reshape(new_shape.clone()) {
Self {
data: Arc::clone(&self.data),
layout: new_layout,
dtype: self.dtype,
}
} else {
self.to_contiguous().reshape(new_shape)
}
}
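    /// Swaps two dimensions as a zero-copy view.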
pub fn transpose(&self, dim1: usize, dim2: usize) -> Self {
Self {
data: Arc::clone(&self.data),
layout: self.layout.transpose(dim1, dim2),
dtype: self.dtype,
}
}
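    /// Restricts `dim` to `len` elements starting at `start`, as a zero-copy view.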
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Self {
Self {
data: Arc::clone(&self.data),
layout: self.layout.narrow(dim, start, len),
dtype: self.dtype,
}
}
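    /// Reorders the dimensions according to `axes`, as a zero-copy view.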
pub fn permute(&self, axes: &[usize]) -> Self {
Self {
data: Arc::clone(&self.data),
layout: self.layout.permute(axes),
dtype: self.dtype,
}
}
}
impl TensorMetadata for FlexTensor {
fn dtype(&self) -> DType {
self.dtype
}
fn shape(&self) -> Shape {
self.layout.shape().clone()
}
fn rank(&self) -> usize {
self.layout.num_dims()
}
}
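/// Maximum rank handled by the stack-allocated [`CollapsedLayout`].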
const COLLAPSE_MAX_RANK: usize = 8;
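/// Shape and strides collapsed by [`collapse_for_copy`], stored inline to avoid a
/// heap allocation on the copy hot path.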
#[derive(Debug, Clone, Copy)]
struct CollapsedLayout {
ndim: usize,
shape: [usize; COLLAPSE_MAX_RANK],
strides: [isize; COLLAPSE_MAX_RANK],
}
impl CollapsedLayout {
#[inline]
fn as_slices(&self) -> (&[usize], &[isize]) {
(&self.shape[..self.ndim], &self.strides[..self.ndim])
}
}
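/// Collapses a shape/stride pair for copying: size-1 dims are dropped and adjacent
/// dims whose strides compose contiguously are merged (e.g. shape `[2, 3, 4]` with
/// strides `[12, 4, 1]` collapses to `[24]` / `[1]`, see the unit tests). Ranks
/// above `COLLAPSE_MAX_RANK` are left un-collapsed so the caller falls back to the
/// generic strided path.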
fn collapse_for_copy(shape: &[usize], strides: &[isize]) -> CollapsedLayout {
let mut out = CollapsedLayout {
ndim: 0,
shape: [0; COLLAPSE_MAX_RANK],
strides: [0; COLLAPSE_MAX_RANK],
};
    if shape.len() > COLLAPSE_MAX_RANK {
        // Rank exceeds the inline buffers: report an un-collapsed rank so that
        // `copy_contiguous` falls back to the generic strided-iterator path.
        out.ndim = COLLAPSE_MAX_RANK;
        return out;
    }
for (&s, &st) in shape.iter().zip(strides.iter()) {
if s == 1 {
continue;
}
let merge = out.ndim > 0
&& (s as isize)
.checked_mul(st)
.is_some_and(|run| out.strides[out.ndim - 1] == run);
if merge {
out.shape[out.ndim - 1] *= s;
out.strides[out.ndim - 1] = st;
} else {
out.shape[out.ndim] = s;
out.strides[out.ndim] = st;
out.ndim += 1;
}
}
out
}
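/// Gathers a strided 2-D view into the row-major `dst` using 16x16 tiles, keeping
/// the dimension with the smaller stride in the innermost loop for better locality.
/// `dst` must hold at least `rows * cols` elements.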
#[inline]
fn copy_2d_tiled<E: Copy>(
dst: &mut [E],
src: &[E],
offset: isize,
rows: usize,
cols: usize,
row_stride: isize,
col_stride: isize,
) {
const TILE: usize = 16;
if row_stride <= col_stride {
for col_tile in (0..cols).step_by(TILE) {
let col_end = (col_tile + TILE).min(cols);
for row_tile in (0..rows).step_by(TILE) {
let row_end = (row_tile + TILE).min(rows);
for col in col_tile..col_end {
let col_base = offset + col as isize * col_stride;
for row in row_tile..row_end {
let idx = (col_base + row as isize * row_stride) as usize;
unsafe {
*dst.get_unchecked_mut(row * cols + col) = src[idx];
}
}
}
}
}
} else {
for row_tile in (0..rows).step_by(TILE) {
let row_end = (row_tile + TILE).min(rows);
for col_tile in (0..cols).step_by(TILE) {
let col_end = (col_tile + TILE).min(cols);
for row in row_tile..row_end {
let row_base =
offset + row as isize * row_stride + col_tile as isize * col_stride;
let dst_base = row * cols + col_tile;
for c in 0..(col_end - col_tile) {
let idx = (row_base + c as isize * col_stride) as usize;
unsafe {
*dst.get_unchecked_mut(dst_base + c) = src[idx];
}
}
}
}
}
}
}
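/// Element size in bytes for `dtype`; panics for sub-byte packed dtypes, which this
/// backend does not support yet.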
pub(crate) fn dtype_size(dtype: DType) -> usize {
let size = dtype.size();
assert!(
size > 0,
"burn-flex: dtype {:?} has zero-byte element size (sub-byte packed \
quantization is not yet supported)",
dtype
);
size
}
#[cfg(test)]
mod tests {
use super::*;
use alloc::vec;
#[test]
fn test_from_data_roundtrip() {
let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
let tensor = FlexTensor::from_data(data.clone());
let result = tensor.into_data();
assert_eq!(data.shape, result.shape);
assert_eq!(data.dtype, result.dtype);
}
#[test]
fn test_collapse_for_copy_squeezes_size1_and_merges_contig() {
let shape = vec![1, 244, 224, 48];
let strides = vec![2_623_488_isize, 224, 1, 54656];
let collapsed = collapse_for_copy(&shape, &strides);
let (s, st) = collapsed.as_slices();
assert_eq!(s, &[54656, 48]);
assert_eq!(st, &[1, 54656]);
}
#[test]
fn test_collapse_for_copy_already_contiguous_3d() {
let collapsed = collapse_for_copy(&[2, 3, 4], &[12, 4, 1]);
let (s, st) = collapsed.as_slices();
assert_eq!(s, &[24]);
assert_eq!(st, &[1]);
}
#[test]
fn test_collapse_for_copy_transpose_2d() {
let collapsed = collapse_for_copy(&[5, 3], &[1, 5]);
let (s, st) = collapsed.as_slices();
assert_eq!(s, &[5, 3]);
assert_eq!(st, &[1, 5]);
}
#[test]
fn test_collapse_for_copy_all_size1() {
let collapsed = collapse_for_copy(&[1, 1, 1], &[0, 0, 0]);
let (s, st) = collapsed.as_slices();
assert!(s.is_empty());
assert!(st.is_empty());
}
#[test]
fn test_to_contiguous_zero_sized_narrowed() {
let t = FlexTensor::from_data(TensorData::new(
(0..6).map(|i| i as f32).collect::<Vec<_>>(),
vec![6],
));
let empty_view = t.narrow(0, 3, 0);
assert_eq!(empty_view.shape().to_vec(), vec![0]);
assert_ne!(empty_view.layout().start_offset(), 0);
let contig = empty_view.to_contiguous();
assert_eq!(contig.shape().to_vec(), vec![0]);
assert_eq!(contig.layout().start_offset(), 0);
assert_eq!(contig.into_data().bytes.len(), 0);
}
#[test]
fn test_to_contiguous_prefix_view_shrinks_buffer() {
let data: Vec<f32> = (0..40).map(|i| i as f32).collect();
let t = FlexTensor::from_data(TensorData::new(data, vec![8, 5]));
let prefix = t.narrow(0, 0, 5);
assert_eq!(prefix.shape().to_vec(), vec![5, 5]);
assert_eq!(prefix.layout().strides(), &[5, 1]);
assert_eq!(prefix.layout().start_offset(), 0);
assert!(prefix.is_contiguous());
assert_eq!(prefix.storage::<f32>().len(), 40);
let contig = prefix.to_contiguous();
assert_eq!(contig.storage::<f32>().len(), 25);
assert_eq!(contig.layout().num_elements(), 25);
assert_eq!(
contig.storage::<f32>(),
&(0..5)
.flat_map(|r| (0..5).map(move |c| (r * 5 + c) as f32))
.collect::<Vec<_>>()[..]
);
}
#[test]
fn test_to_contiguous_4d_permuted_matches_naive() {
let dims = [1, 48, 4, 5];
let n: usize = dims.iter().product();
let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
let t = FlexTensor::from_data(TensorData::new(data.clone(), dims.to_vec()));
let permuted = t.permute(&[0, 2, 3, 1]);
assert!(!permuted.is_contiguous());
let contig = permuted.to_contiguous();
assert!(contig.is_contiguous());
assert_eq!(contig.shape().to_vec(), vec![1, 4, 5, 48]);
let mut expected = Vec::with_capacity(n);
for h in 0..4 {
for w in 0..5 {
for c in 0..48 {
let idx = c * 20 + h * 5 + w;
expected.push(data[idx]);
}
}
}
let result_data = contig.into_data();
let values = result_data.as_slice::<f32>().unwrap();
assert_eq!(values, expected.as_slice());
}
#[test]
fn test_to_contiguous_2d_row_stride_gt_col_stride() {
let data: Vec<f32> = (0..18).map(|i| i as f32).collect();
let t = FlexTensor::from_data(TensorData::new(data, vec![6, 3]));
let stepped = crate::ops::slice::slice(
t,
&[
burn_std::Slice::new(0, Some(6), 2),
burn_std::Slice::new(0, None, 1),
],
);
assert_eq!(stepped.layout().shape().to_vec(), vec![3, 3]);
assert_eq!(stepped.layout().strides(), &[6, 1]);
assert!(!stepped.layout().is_contiguous());
let contig = stepped.to_contiguous();
assert!(contig.is_contiguous());
assert_eq!(contig.shape().to_vec(), vec![3, 3]);
let result_data = contig.into_data();
let values = result_data.as_slice::<f32>().unwrap();
        let expected = vec![0.0f32, 1.0, 2.0, 6.0, 7.0, 8.0, 12.0, 13.0, 14.0];
assert_eq!(values, expected.as_slice());
}
#[test]
fn test_reshape() {
let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
let tensor = FlexTensor::from_data(data);
let reshaped = tensor.reshape(Shape::from(vec![3, 2]));
assert_eq!(reshaped.shape().to_vec(), vec![3, 2]);
}
#[test]
fn test_transpose() {
let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
let tensor = FlexTensor::from_data(data);
let transposed = tensor.transpose(0, 1);
assert_eq!(transposed.shape().to_vec(), vec![3, 2]);
assert!(!transposed.is_contiguous());
}
#[test]
fn test_clone_is_cheap() {
let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
let tensor = FlexTensor::from_data(data);
assert!(tensor.is_unique());
let cloned = tensor.clone();
assert!(!tensor.is_unique());
assert!(!cloned.is_unique());
assert!(core::ptr::eq(
tensor.bytes().as_ptr(),
cloned.bytes().as_ptr()
));
}
#[test]
fn test_cow_on_mutation() {
let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
let tensor = FlexTensor::from_data(data);
let mut cloned = tensor.clone();
assert!(!tensor.is_unique());
assert!(!cloned.is_unique());
let storage: &mut [f32] = cloned.storage_mut();
storage[0] = 99.0;
assert!(tensor.is_unique());
assert!(cloned.is_unique());
assert_ne!(tensor.bytes().as_ptr(), cloned.bytes().as_ptr());
assert_eq!(tensor.storage::<f32>()[0], 1.0);
assert_eq!(cloned.storage::<f32>()[0], 99.0);
}
#[test]
fn test_into_data_narrowed_at_offset_zero() {
let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
let tensor = FlexTensor::from_data(data);
let narrowed = tensor.narrow(0, 0, 1);
assert!(narrowed.is_contiguous());
assert_eq!(narrowed.layout().start_offset(), 0);
let result = narrowed.into_data();
assert_eq!(result.shape.to_vec(), vec![1, 3]);
assert_eq!(result.bytes.len(), 3 * core::mem::size_of::<f32>());
let values: Vec<f32> = result.to_vec().unwrap();
assert_eq!(values, vec![1.0, 2.0, 3.0]);
}
}