#[cfg(target_has_atomic = "ptr")]
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::fmt;
#[cfg(not(target_has_atomic = "ptr"))]
use portable_atomic_util::Arc;

use burn_backend::{DType, Element, TensorData, TensorMetadata};
use burn_std::{Bytes, Shape, bf16, f16};

use crate::layout::Layout;

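/// A tensor handle backed by reference-counted byte storage.
///
/// Cloning is cheap: clones share the same underlying `Bytes` buffer and only
/// copy the [`Layout`] and [`DType`]. Mutation goes through copy-on-write (see
/// [`FlexTensor::storage_mut`]), so a shared buffer is never modified in place.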
#[derive(Clone)]
pub struct FlexTensor {
    data: Arc<Bytes>,
    layout: Layout,
    dtype: DType,
}

impl fmt::Debug for FlexTensor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("FlexTensor")
            .field("shape", self.layout.shape())
            .field("dtype", &self.dtype)
            .field("contiguous", &self.layout.is_contiguous())
            .field("unique", &self.is_unique())
            .finish()
    }
}

impl FlexTensor {
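    /// Creates a tensor from raw bytes with an explicit layout and dtype.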
    pub fn new(data: Bytes, layout: Layout, dtype: DType) -> Self {
        Self {
            data: Arc::new(data),
            layout,
            dtype,
        }
    }

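    /// Creates a contiguous tensor that takes ownership of the bytes in `data`.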
    pub fn from_data(data: TensorData) -> Self {
        let shape = data.shape.clone();
        let layout = Layout::contiguous(shape);
        let dtype = data.dtype;
        Self {
            data: Arc::new(data.bytes),
            layout,
            dtype,
        }
    }

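    /// Converts the tensor into a [`TensorData`].
    ///
    /// A contiguous tensor starting at offset zero hands its buffer over
    /// without copying when it is uniquely owned and exactly sized; otherwise
    /// only the visible bytes are copied (via `to_contiguous` for strided views).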
    pub fn into_data(self) -> TensorData {
        if self.layout.is_contiguous() && self.layout.start_offset() == 0 {
            let expected_bytes = self.layout.num_elements() * dtype_size(self.dtype);
            assert!(
                expected_bytes <= self.data.len(),
                "into_data: buffer ({} bytes) too small for {} elements of {:?}",
                self.data.len(),
                self.layout.num_elements(),
                self.dtype
            );
            if self.data.len() == expected_bytes {
                // Exact-size buffer: hand it over without copying when uniquely owned.
                match Arc::try_unwrap(self.data) {
                    Ok(bytes) => TensorData {
                        bytes,
                        shape: self.layout.shape().clone(),
                        dtype: self.dtype,
                    },
                    Err(arc) => {
                        let bytes = Bytes::from_bytes_vec((*arc)[..expected_bytes].to_vec());
                        TensorData {
                            bytes,
                            shape: self.layout.shape().clone(),
                            dtype: self.dtype,
                        }
                    }
                }
            } else {
                // Contiguous prefix view over a larger buffer: copy only the bytes we need.
                let bytes = Bytes::from_bytes_vec(self.data[..expected_bytes].to_vec());
                TensorData {
                    bytes,
                    shape: self.layout.shape().clone(),
                    dtype: self.dtype,
                }
            }
        } else {
            // Strided or offset view: materialize a contiguous copy first.
            self.to_contiguous().into_data()
        }
    }

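    /// Returns `true` if this tensor is the only owner of its storage.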
    #[inline]
    pub fn is_unique(&self) -> bool {
        Arc::strong_count(&self.data) == 1
    }

    pub fn layout(&self) -> &Layout {
        &self.layout
    }

    pub fn with_layout(self, layout: Layout) -> Self {
        Self {
            data: self.data,
            layout,
            dtype: self.dtype,
        }
    }

    pub fn dtype(&self) -> DType {
        self.dtype
    }

    pub fn is_contiguous(&self) -> bool {
        self.layout.is_contiguous()
    }

    pub fn bytes(&self) -> &[u8] {
        &self.data
    }

    pub fn data_arc(&self) -> Arc<Bytes> {
        Arc::clone(&self.data)
    }

    pub fn from_arc(data: Arc<Bytes>, layout: Layout, dtype: DType) -> Self {
        Self {
            data,
            layout,
            dtype,
        }
    }

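    /// Returns the entire underlying storage as a typed slice.
    ///
    /// This exposes the whole buffer, not just the elements visible through the
    /// current layout. Panics if `E` does not match the tensor's dtype
    /// (byte-backed bools are accepted as `u8`).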
    pub fn storage<E: Element + bytemuck::Pod>(&self) -> &[E] {
        assert!(
            E::dtype() == self.dtype
                || (matches!(
                    self.dtype,
                    DType::Bool(burn_std::BoolStore::Native | burn_std::BoolStore::U8)
                ) && E::dtype() == DType::U8),
            "storage: dtype mismatch (expected {:?}, got {:?})",
            self.dtype,
            E::dtype()
        );
        bytemuck::cast_slice(&self.data)
    }

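    /// Returns the entire underlying storage as a mutable typed slice, cloning
    /// the buffer first if it is shared (copy-on-write).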
    pub fn storage_mut<E: Element + bytemuck::Pod>(&mut self) -> &mut [E] {
        assert!(
            E::dtype() == self.dtype
                || (matches!(
                    self.dtype,
                    DType::Bool(burn_std::BoolStore::Native | burn_std::BoolStore::U8)
                ) && E::dtype() == DType::U8),
            "storage_mut: dtype mismatch (expected {:?}, got {:?})",
            self.dtype,
            E::dtype()
        );
        // Copy-on-write: detach from any shared buffer before handing out `&mut`.
        let bytes = Arc::make_mut(&mut self.data);
        bytemuck::cast_slice_mut(bytes)
    }

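    /// Returns a mutable typed view of the storage only if it is uniquely
    /// owned; returns `None` instead of cloning when the buffer is shared.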
    pub fn try_storage_mut<E: Element + bytemuck::Pod>(&mut self) -> Option<&mut [E]> {
        assert!(
            E::dtype() == self.dtype
                || (matches!(
                    self.dtype,
                    DType::Bool(burn_std::BoolStore::Native | burn_std::BoolStore::U8)
                ) && E::dtype() == DType::U8),
            "try_storage_mut: dtype mismatch (expected {:?}, got {:?})",
            self.dtype,
            E::dtype()
        );
        if self.is_unique() {
            let bytes = Arc::get_mut(&mut self.data)?;
            Some(bytemuck::cast_slice_mut(bytes))
        } else {
            None
        }
    }

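    /// Returns the elements visible through the current layout as a typed
    /// slice, or `None` if the dtype mismatches or the view is not contiguous.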
    pub fn as_slice<E: Element + bytemuck::Pod>(&self) -> Option<&[E]> {
        if E::dtype() != self.dtype {
            return None;
        }
        let storage: &[E] = self.storage();
        self.layout
            .contiguous_offsets()
            .map(|(start, end)| &storage[start..end])
    }

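    /// Allocates a zero-initialized contiguous tensor of the given shape and dtype.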
    pub fn empty(shape: Shape, dtype: DType) -> Self {
        let num_elements = shape.num_elements();
        let elem_size = dtype_size(dtype);
        let bytes = Bytes::from_bytes_vec(alloc::vec![0u8; num_elements * elem_size]);
        let layout = Layout::contiguous(shape);
        Self {
            data: Arc::new(bytes),
            layout,
            dtype,
        }
    }

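    /// Creates a tensor filled with zeros. Since [`FlexTensor::empty`]
    /// zero-initializes its buffer, this simply delegates to it.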
    pub fn zeros(shape: Shape, dtype: DType) -> Self {
        Self::empty(shape, dtype)
    }

    pub fn filled_typed<E: bytemuck::Pod + Send + Sync>(
        shape: Shape,
        dtype: DType,
        value: E,
    ) -> Self {
        assert_eq!(
            dtype_size(dtype),
            core::mem::size_of::<E>(),
            "filled_typed: dtype size mismatch"
        );
        let n = shape.num_elements();
        let data = alloc::vec![value; n];
        let bytes = Bytes::from_elems(data);
        Self {
            data: Arc::new(bytes),
            layout: Layout::contiguous(shape),
            dtype,
        }
    }

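    /// Returns a tensor whose elements are laid out contiguously in row-major
    /// order, copying only when the current view is strided, offset, or backed
    /// by a larger buffer than it needs.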
    pub fn to_contiguous(&self) -> Self {
        if self.is_contiguous()
            && self.layout.start_offset() == 0
            && self.data.len() == self.layout.num_elements() * dtype_size(self.dtype)
        {
            return self.clone();
        }

        match self.dtype {
            DType::F64 => self.copy_contiguous::<f64>(),
            DType::F32 => self.copy_contiguous::<f32>(),
            DType::F16 => self.copy_contiguous::<f16>(),
            DType::BF16 => self.copy_contiguous::<bf16>(),
            DType::I64 => self.copy_contiguous::<i64>(),
            DType::I32 => self.copy_contiguous::<i32>(),
            DType::I16 => self.copy_contiguous::<i16>(),
            DType::I8 => self.copy_contiguous::<i8>(),
            DType::U64 => self.copy_contiguous::<u64>(),
            DType::U32 => self.copy_contiguous::<u32>(),
            DType::U16 => self.copy_contiguous::<u16>(),
            DType::U8 => self.copy_contiguous::<u8>(),
            DType::Bool(burn_std::BoolStore::Native | burn_std::BoolStore::U8) => {
                self.copy_contiguous::<u8>()
            }
            DType::Bool(burn_std::BoolStore::U32) => {
                panic!("burn-flex: Bool(U32) storage is not yet supported")
            }
            _ => panic!("Unsupported dtype for contiguous copy: {:?}", self.dtype),
        }
    }

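    /// Materializes the current view into a fresh contiguous buffer.
    ///
    /// The layout is first collapsed to the smallest equivalent rank; 1-D and
    /// 2-D runs use fast paths (a straight copy and a tiled copy respectively),
    /// and everything else falls back to a generic strided iterator.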
    fn copy_contiguous<E: Element + bytemuck::Pod>(&self) -> Self {
        let src: &[E] = bytemuck::cast_slice(&self.data);
        let n = self.layout.num_elements();
        let mut dst = Vec::with_capacity(n);

        // Collapse the layout to the smallest equivalent rank so that common
        // cases (already contiguous, transposed 2-D) hit the fast paths below.
        let collapsed = collapse_for_copy(self.layout.shape(), self.layout.strides());
        let (shape, strides) = collapsed.as_slices();
        let offset = self.layout.start_offset() as isize;
        let all_positive = strides.iter().all(|&s| s >= 0);

        if shape.len() <= 1 && all_positive {
            let collapsed_numel = if shape.is_empty() { 1 } else { shape[0] };
            debug_assert_eq!(n, collapsed_numel);
            // SAFETY: every slot of `dst` is written below before being read,
            // and `E: Pod` has no drop glue or validity requirements.
            unsafe { dst.set_len(n) };
            if shape.is_empty() {
                if n > 0 {
                    dst[0] = src[offset as usize];
                }
            } else {
                let len = shape[0];
                let stride = strides[0];
                if stride == 1 {
                    dst[..len].copy_from_slice(&src[offset as usize..offset as usize + len]);
                } else {
                    for (i, slot) in dst.iter_mut().take(len).enumerate() {
                        let idx = (offset + i as isize * stride) as usize;
                        *slot = src[idx];
                    }
                }
            }
        } else if shape.len() == 2 && all_positive {
            debug_assert_eq!(
                shape[0] * shape[1],
                n,
                "collapsed 2-D shape must cover all elements"
            );
            // SAFETY: `copy_2d_tiled` writes every element of `dst` exactly once.
            unsafe { dst.set_len(n) };
            copy_2d_tiled(
                &mut dst, src, offset, shape[0], shape[1], strides[0], strides[1],
            );
        } else {
            // Generic fallback: walk the original layout element by element.
            for idx in crate::strided_index::StridedIter::new(&self.layout) {
                dst.push(src[idx]);
            }
        }

        let bytes = Bytes::from_elems(dst);
        let layout = Layout::contiguous(self.layout.shape().clone());
        Self {
            data: Arc::new(bytes),
            layout,
            dtype: self.dtype,
        }
    }

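    /// Reshapes the tensor, reusing the existing buffer when the layout allows
    /// it and falling back to a contiguous copy otherwise. Panics if the new
    /// shape changes the total number of elements.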
    pub fn reshape(&self, new_shape: Shape) -> Self {
        assert_eq!(
            self.layout.num_elements(),
            new_shape.num_elements(),
            "reshape must preserve total elements"
        );

        if let Some(new_layout) = self.layout.reshape(new_shape.clone()) {
            Self {
                data: Arc::clone(&self.data),
                layout: new_layout,
                dtype: self.dtype,
            }
        } else {
            // The strides cannot express the new shape: copy, then reshape the copy.
            self.to_contiguous().reshape(new_shape)
        }
    }

    pub fn transpose(&self, dim1: usize, dim2: usize) -> Self {
        Self {
            data: Arc::clone(&self.data),
            layout: self.layout.transpose(dim1, dim2),
            dtype: self.dtype,
        }
    }

    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Self {
        Self {
            data: Arc::clone(&self.data),
            layout: self.layout.narrow(dim, start, len),
            dtype: self.dtype,
        }
    }

    pub fn permute(&self, axes: &[usize]) -> Self {
        Self {
            data: Arc::clone(&self.data),
            layout: self.layout.permute(axes),
            dtype: self.dtype,
        }
    }
}

impl TensorMetadata for FlexTensor {
    fn dtype(&self) -> DType {
        self.dtype
    }

    fn shape(&self) -> Shape {
        self.layout.shape().clone()
    }

    fn rank(&self) -> usize {
        self.layout.num_dims()
    }
}

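/// Maximum input rank that [`collapse_for_copy`] will attempt to collapse;
/// higher ranks fall back to the generic strided copy path.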
const COLLAPSE_MAX_RANK: usize = 8;

/// A fixed-capacity (shape, strides) pair produced by [`collapse_for_copy`].
#[derive(Debug, Clone, Copy)]
struct CollapsedLayout {
    ndim: usize,
    shape: [usize; COLLAPSE_MAX_RANK],
    strides: [isize; COLLAPSE_MAX_RANK],
}

impl CollapsedLayout {
    #[inline]
    fn as_slices(&self) -> (&[usize], &[isize]) {
        (&self.shape[..self.ndim], &self.strides[..self.ndim])
    }
}

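/// Collapses a strided layout into the smallest equivalent (shape, strides)
/// pair for copying: size-1 dimensions are dropped and adjacent dimensions
/// that together form a single contiguous run are merged. Inputs with more
/// than [`COLLAPSE_MAX_RANK`] dimensions are not collapsed; the returned rank
/// stays above 2 so callers take the generic strided fallback.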
fn collapse_for_copy(shape: &[usize], strides: &[isize]) -> CollapsedLayout {
    let mut out = CollapsedLayout {
        ndim: 0,
        shape: [0; COLLAPSE_MAX_RANK],
        strides: [0; COLLAPSE_MAX_RANK],
    };

    if shape.len() > COLLAPSE_MAX_RANK {
        // Too many dimensions to collapse in place: report a rank above 2 so
        // the caller falls back to the generic strided copy.
        out.ndim = shape.len().min(COLLAPSE_MAX_RANK);
        return out;
    }

    for (&s, &st) in shape.iter().zip(strides.iter()) {
        // Size-1 dimensions never affect element order; drop them.
        if s == 1 {
            continue;
        }
        // Merge with the previous dimension when the two form one contiguous
        // run, i.e. previous_stride == s * st.
        let merge = out.ndim > 0
            && (s as isize)
                .checked_mul(st)
                .is_some_and(|run| out.strides[out.ndim - 1] == run);
        if merge {
            out.shape[out.ndim - 1] *= s;
            out.strides[out.ndim - 1] = st;
        } else {
            out.shape[out.ndim] = s;
            out.strides[out.ndim] = st;
            out.ndim += 1;
        }
    }

    out
}

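/// Copies a 2-D strided view of `src` into the row-major buffer `dst` using
/// 16x16 tiles, ordering the loops so the innermost one walks the axis with
/// the smaller stride. `dst` must hold at least `rows * cols` elements.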
#[inline]
fn copy_2d_tiled<E: Copy>(
    dst: &mut [E],
    src: &[E],
    offset: isize,
    rows: usize,
    cols: usize,
    row_stride: isize,
    col_stride: isize,
) {
    const TILE: usize = 16;

    if row_stride <= col_stride {
        // Rows are the fast-moving (smaller-stride) axis of the source, so walk
        // rows in the innermost loop.
        for col_tile in (0..cols).step_by(TILE) {
            let col_end = (col_tile + TILE).min(cols);
            for row_tile in (0..rows).step_by(TILE) {
                let row_end = (row_tile + TILE).min(rows);
                for col in col_tile..col_end {
                    let col_base = offset + col as isize * col_stride;
                    for row in row_tile..row_end {
                        let idx = (col_base + row as isize * row_stride) as usize;
                        // SAFETY: `row < rows` and `col < cols`, so the index
                        // stays within the `rows * cols` elements of `dst`.
                        unsafe {
                            *dst.get_unchecked_mut(row * cols + col) = src[idx];
                        }
                    }
                }
            }
        }
    } else {
        // Columns are the fast-moving axis of the source, so walk columns in
        // the innermost loop.
        for row_tile in (0..rows).step_by(TILE) {
            let row_end = (row_tile + TILE).min(rows);
            for col_tile in (0..cols).step_by(TILE) {
                let col_end = (col_tile + TILE).min(cols);
                for row in row_tile..row_end {
                    let row_base =
                        offset + row as isize * row_stride + col_tile as isize * col_stride;
                    let dst_base = row * cols + col_tile;
                    for c in 0..(col_end - col_tile) {
                        let idx = (row_base + c as isize * col_stride) as usize;
                        // SAFETY: `dst_base + c < rows * cols`, within `dst`.
                        unsafe {
                            *dst.get_unchecked_mut(dst_base + c) = src[idx];
                        }
                    }
                }
            }
        }
    }
}

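/// Returns the size in bytes of a single element of `dtype`, panicking for
/// sub-byte (packed) dtypes that this backend does not support yet.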
pub(crate) fn dtype_size(dtype: DType) -> usize {
    let size = dtype.size();
    assert!(
        size > 0,
        "burn-flex: dtype {:?} has zero-byte element size (sub-byte packed \
         quantization is not yet supported)",
        dtype
    );
    size
}

#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn test_from_data_roundtrip() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let tensor = FlexTensor::from_data(data.clone());
        let result = tensor.into_data();
        assert_eq!(data.shape, result.shape);
        assert_eq!(data.dtype, result.dtype);
    }

    #[test]
    fn test_collapse_for_copy_squeezes_size1_and_merges_contig() {
        let shape = vec![1, 244, 224, 48];
        let strides = vec![2_623_488_isize, 224, 1, 54656];
        let collapsed = collapse_for_copy(&shape, &strides);
        let (s, st) = collapsed.as_slices();
        assert_eq!(s, &[54656, 48]);
        assert_eq!(st, &[1, 54656]);
    }

    #[test]
    fn test_collapse_for_copy_already_contiguous_3d() {
        let collapsed = collapse_for_copy(&[2, 3, 4], &[12, 4, 1]);
        let (s, st) = collapsed.as_slices();
        assert_eq!(s, &[24]);
        assert_eq!(st, &[1]);
    }

    #[test]
    fn test_collapse_for_copy_transpose_2d() {
        let collapsed = collapse_for_copy(&[5, 3], &[1, 5]);
        let (s, st) = collapsed.as_slices();
        assert_eq!(s, &[5, 3]);
        assert_eq!(st, &[1, 5]);
    }

    #[test]
    fn test_collapse_for_copy_all_size1() {
        let collapsed = collapse_for_copy(&[1, 1, 1], &[0, 0, 0]);
        let (s, st) = collapsed.as_slices();
        assert!(s.is_empty());
        assert!(st.is_empty());
    }

    // A zero-length view taken at a non-zero offset must still collapse to an
    // empty contiguous tensor with offset zero.
    #[test]
    fn test_to_contiguous_zero_sized_narrowed() {
        let t = FlexTensor::from_data(TensorData::new(
            (0..6).map(|i| i as f32).collect::<Vec<_>>(),
            vec![6],
        ));
        let empty_view = t.narrow(0, 3, 0);
        assert_eq!(empty_view.shape().to_vec(), vec![0]);
        assert_ne!(empty_view.layout().start_offset(), 0);

        let contig = empty_view.to_contiguous();
        assert_eq!(contig.shape().to_vec(), vec![0]);
        assert_eq!(contig.layout().start_offset(), 0);
        assert_eq!(contig.into_data().bytes.len(), 0);
    }

    // A narrowed prefix view is still contiguous but only covers part of the
    // buffer; to_contiguous must shrink the storage to the visible elements.
    #[test]
    fn test_to_contiguous_prefix_view_shrinks_buffer() {
        let data: Vec<f32> = (0..40).map(|i| i as f32).collect();
        let t = FlexTensor::from_data(TensorData::new(data, vec![8, 5]));

        let prefix = t.narrow(0, 0, 5);
        assert_eq!(prefix.shape().to_vec(), vec![5, 5]);
        assert_eq!(prefix.layout().strides(), &[5, 1]);
        assert_eq!(prefix.layout().start_offset(), 0);
        assert!(prefix.is_contiguous());
        assert_eq!(prefix.storage::<f32>().len(), 40);

        let contig = prefix.to_contiguous();
        assert_eq!(contig.storage::<f32>().len(), 25);
        assert_eq!(contig.layout().num_elements(), 25);
        assert_eq!(
            contig.storage::<f32>(),
            &(0..5)
                .flat_map(|r| (0..5).map(move |c| (r * 5 + c) as f32))
                .collect::<Vec<_>>()[..]
        );
    }

    // A 4-D NCHW -> NHWC permutation must match a naively computed reference.
    #[test]
    fn test_to_contiguous_4d_permuted_matches_naive() {
        let dims = [1, 48, 4, 5];
        let n: usize = dims.iter().product();
        let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let t = FlexTensor::from_data(TensorData::new(data.clone(), dims.to_vec()));
        let permuted = t.permute(&[0, 2, 3, 1]);
        assert!(!permuted.is_contiguous());

        let contig = permuted.to_contiguous();
        assert!(contig.is_contiguous());
        assert_eq!(contig.shape().to_vec(), vec![1, 4, 5, 48]);

        // Element (h, w, c) of the result comes from index c * 20 + h * 5 + w
        // in the original row-major data.
        let mut expected = Vec::with_capacity(n);
        for h in 0..4 {
            for w in 0..5 {
                for c in 0..48 {
                    let idx = c * 20 + h * 5 + w;
                    expected.push(data[idx]);
                }
            }
        }

        let result_data = contig.into_data();
        let values = result_data.as_slice::<f32>().unwrap();
        assert_eq!(values, expected.as_slice());
    }

    // A step-2 slice along dim 0 gives row_stride > col_stride, exercising the
    // second branch of copy_2d_tiled.
    #[test]
    fn test_to_contiguous_2d_row_stride_gt_col_stride() {
        let data: Vec<f32> = (0..18).map(|i| i as f32).collect();
        let t = FlexTensor::from_data(TensorData::new(data, vec![6, 3]));
        let stepped = crate::ops::slice::slice(
            t,
            &[
                burn_std::Slice::new(0, Some(6), 2),
                burn_std::Slice::new(0, None, 1),
            ],
        );
        assert_eq!(stepped.layout().shape().to_vec(), vec![3, 3]);
        assert_eq!(stepped.layout().strides(), &[6, 1]);
        assert!(!stepped.layout().is_contiguous());

        let contig = stepped.to_contiguous();
        assert!(contig.is_contiguous());
        assert_eq!(contig.shape().to_vec(), vec![3, 3]);

        let result_data = contig.into_data();
        let values = result_data.as_slice::<f32>().unwrap();
        // Rows 0, 2 and 4 of the original 6x3 tensor.
        let expected = vec![0.0f32, 1.0, 2.0, 6.0, 7.0, 8.0, 12.0, 13.0, 14.0];
        assert_eq!(values, expected.as_slice());
    }

    #[test]
    fn test_reshape() {
        let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let tensor = FlexTensor::from_data(data);
        let reshaped = tensor.reshape(Shape::from(vec![3, 2]));
        assert_eq!(reshaped.shape().to_vec(), vec![3, 2]);
    }

    #[test]
    fn test_transpose() {
        let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let tensor = FlexTensor::from_data(data);
        let transposed = tensor.transpose(0, 1);
        assert_eq!(transposed.shape().to_vec(), vec![3, 2]);
        assert!(!transposed.is_contiguous());
    }

    #[test]
    fn test_clone_is_cheap() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
        let tensor = FlexTensor::from_data(data);

        assert!(tensor.is_unique());

        let cloned = tensor.clone();
        assert!(!tensor.is_unique());
        assert!(!cloned.is_unique());

        // Both handles point at the same underlying buffer.
        assert!(core::ptr::eq(
            tensor.bytes().as_ptr(),
            cloned.bytes().as_ptr()
        ));
    }

    #[test]
    fn test_cow_on_mutation() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
        let tensor = FlexTensor::from_data(data);
        let mut cloned = tensor.clone();

        assert!(!tensor.is_unique());
        assert!(!cloned.is_unique());

        // Mutating the clone triggers copy-on-write.
        let storage: &mut [f32] = cloned.storage_mut();
        storage[0] = 99.0;

        assert!(tensor.is_unique());
        assert!(cloned.is_unique());

        // The buffers are now distinct and the original is unchanged.
        assert_ne!(tensor.bytes().as_ptr(), cloned.bytes().as_ptr());
        assert_eq!(tensor.storage::<f32>()[0], 1.0);
        assert_eq!(cloned.storage::<f32>()[0], 99.0);
    }

    #[test]
    fn test_into_data_narrowed_at_offset_zero() {
        let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let tensor = FlexTensor::from_data(data);
        let narrowed = tensor.narrow(0, 0, 1);
        assert!(narrowed.is_contiguous());
        assert_eq!(narrowed.layout().start_offset(), 0);

        let result = narrowed.into_data();
        assert_eq!(result.shape.to_vec(), vec![1, 3]);
        assert_eq!(result.bytes.len(), 3 * core::mem::size_of::<f32>());
        let values: Vec<f32> = result.to_vec().unwrap();
        assert_eq!(values, vec![1.0, 2.0, 3.0]);
    }
}