// rlmesh_spaces/tensor.rs

1mod dlpack;
2#[cfg(test)]
3mod proptests;
4mod storage;
5
6pub use dlpack::{DLPackType, dlpack_type, dtype_from_dlpack};
7pub use storage::Storage;
8
9use std::borrow::Cow;
10
11use thiserror::Error;
12
13use crate::dtype::{DType, dtype_size};
14
/// Where a tensor's backing storage resides.
///
/// The discriminants mirror the numeric codes of DLPack's `DLDeviceType`, so
/// a variant can be cast directly to the integer DLPack expects.
#[non_exhaustive]
#[repr(i32)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Device {
    /// Host CPU memory (`kDLCPU`).
    Cpu = 1,
}
23
24impl From<Device> for i32 {
25    fn from(value: Device) -> Self {
26        value as i32
27    }
28}
29
30/// Errors raised by [`Tensor`] constructors and transformations.
31#[derive(Error, Debug, Clone, PartialEq, Eq)]
32pub enum TensorError {
33    #[error("tensor dtype must be specified")]
34    UnspecifiedDtype,
35    #[error("negative dimension {0} in shape")]
36    NegativeDim(i64),
37    #[error("negative stride {0}")]
38    NegativeStride(i64),
39    #[error("strides rank {strides} does not match shape rank {shape}")]
40    StrideRankMismatch { strides: usize, shape: usize },
41    #[error("data is {actual} bytes but shape and dtype require exactly {expected}")]
42    ByteLengthMismatch { expected: usize, actual: usize },
43    #[error(
44        "view requires {required} bytes at byte offset {byte_offset} but storage holds {available}"
45    )]
46    OutOfBounds {
47        required: usize,
48        byte_offset: usize,
49        available: usize,
50    },
51    #[error("cannot reshape {from} elements into {to}")]
52    NumelMismatch { from: usize, to: usize },
53    #[error("reshape shape may contain at most one -1")]
54    AmbiguousReshape,
55    #[error("stack requires at least one tensor")]
56    EmptyStack,
57    #[error("stack requires uniform dtype and shape; tensor {index} differs")]
58    StackMismatch { index: usize },
59    #[error("cannot unstack a 0-dimensional tensor")]
60    UnstackScalar,
61    #[error("tensor size overflows usize")]
62    Overflow,
63}
64
/// C-contiguous (row-major) strides for `shape`, in element units.
///
/// The innermost axis gets stride 1 and every other axis gets the product of
/// all dimensions to its right. The outermost dimension never contributes to
/// any stride, so it is not multiplied in.
///
/// Products that do not fit in `i64` saturate instead of overflowing. Such
/// shapes can still belong to valid zero-element tensors (for example
/// `[0, i64::MAX, i64::MAX]`), so this function must not panic on them.
pub fn contiguous_strides(shape: &[i64]) -> Vec<i64> {
    let mut strides = vec![1i64; shape.len()];
    // Walk from the second-innermost axis outwards, extending the running
    // suffix product by one dimension per step.
    for axis in (0..shape.len().saturating_sub(1)).rev() {
        strides[axis] = strides[axis + 1].saturating_mul(shape[axis + 1]);
    }
    strides
}
75
76/// An n-dimensional, immutable tensor backed by shared [`Storage`].
77///
78/// The layout follows DLPack conventions: C-order shape, element-unit
79/// strides (`None` means C-contiguous), and a byte offset into the backing
80/// storage. Constructors validate that every addressable element falls
81/// inside the storage, so accessors never fail.
82///
83/// Equality is logical: two tensors are equal when dtype, shape, and
84/// element bytes (in C order) match, regardless of how they are laid out
85/// in storage.
86#[derive(Debug, Clone)]
87pub struct Tensor {
88    storage: Storage,
89    dtype: DType,
90    shape: Vec<i64>,
91    strides: Option<Vec<i64>>,
92    byte_offset: usize,
93}
94
95impl Tensor {
96    /// Adopt `data` as a C-contiguous tensor without copying.
97    ///
98    /// `data.len()` must equal exactly `numel * dtype_size`. The buffer keeps
99    /// its original allocation, so no alignment is guaranteed.
100    pub fn from_vec(data: Vec<u8>, shape: Vec<i64>, dtype: DType) -> Result<Self, TensorError> {
101        let expected = checked_nbytes(&shape, dtype)?;
102        if data.len() != expected {
103            return Err(TensorError::ByteLengthMismatch {
104                expected,
105                actual: data.len(),
106            });
107        }
108        Self::from_storage(Storage::from_vec(data), dtype, shape, None, 0)
109    }
110
111    /// Copy `data` into a fresh 64-byte-aligned C-contiguous tensor.
112    ///
113    /// `data.len()` must equal exactly `numel * dtype_size`.
114    pub fn from_slice(data: &[u8], shape: &[i64], dtype: DType) -> Result<Self, TensorError> {
115        let expected = checked_nbytes(shape, dtype)?;
116        if data.len() != expected {
117            return Err(TensorError::ByteLengthMismatch {
118                expected,
119                actual: data.len(),
120            });
121        }
122        Self::from_storage(Storage::from_slice(data), dtype, shape.to_vec(), None, 0)
123    }
124
125    /// A zero-filled, 64-byte-aligned C-contiguous tensor.
126    pub fn zeros(shape: &[i64], dtype: DType) -> Result<Self, TensorError> {
127        let nbytes = checked_nbytes(shape, dtype)?;
128        Self::from_storage(Storage::zeroed(nbytes), dtype, shape.to_vec(), None, 0)
129    }
130
131    /// A tensor view over `storage` with full layout control.
132    ///
133    /// `strides` are in element units; `None` means C-contiguous. Dims and
134    /// strides must be non-negative and every addressable element must fall
135    /// inside the storage.
136    pub fn from_storage(
137        storage: Storage,
138        dtype: DType,
139        shape: Vec<i64>,
140        strides: Option<Vec<i64>>,
141        byte_offset: usize,
142    ) -> Result<Self, TensorError> {
143        if dtype == DType::Unspecified {
144            return Err(TensorError::UnspecifiedDtype);
145        }
146        if let Some(strides) = &strides {
147            if strides.len() != shape.len() {
148                return Err(TensorError::StrideRankMismatch {
149                    strides: strides.len(),
150                    shape: shape.len(),
151                });
152            }
153            for &stride in strides {
154                if stride < 0 {
155                    return Err(TensorError::NegativeStride(stride));
156                }
157            }
158        }
159        let required = required_bytes(&shape, strides.as_deref(), dtype)?;
160        // The offset must stay inside the storage even for zero-element
161        // views: accessors slice `storage[byte_offset..]` unconditionally.
162        if byte_offset > storage.len() || required > storage.len() - byte_offset {
163            return Err(TensorError::OutOfBounds {
164                required,
165                byte_offset,
166                available: storage.len(),
167            });
168        }
169        Ok(Self {
170            storage,
171            dtype,
172            shape,
173            strides,
174            byte_offset,
175        })
176    }
177
178    /// Element data type.
179    pub fn dtype(&self) -> DType {
180        self.dtype
181    }
182
183    /// Dimension sizes in C order.
184    pub fn shape(&self) -> &[i64] {
185        &self.shape
186    }
187
188    /// Explicit element-unit strides, or `None` when C-contiguous.
189    pub fn strides(&self) -> Option<&[i64]> {
190        self.strides.as_deref()
191    }
192
193    /// Element-unit strides, materializing the C-contiguous default.
194    pub fn effective_strides(&self) -> Cow<'_, [i64]> {
195        match &self.strides {
196            Some(strides) => Cow::Borrowed(strides),
197            None => Cow::Owned(contiguous_strides(&self.shape)),
198        }
199    }
200
201    /// Offset in bytes from the start of the storage to the first element.
202    pub fn byte_offset(&self) -> usize {
203        self.byte_offset
204    }
205
206    /// Device the storage lives on. Always [`Device::Cpu`] today.
207    pub fn device(&self) -> Device {
208        Device::Cpu
209    }
210
211    /// The shared backing storage.
212    pub fn storage(&self) -> &Storage {
213        &self.storage
214    }
215
216    /// Number of elements.
217    pub fn numel(&self) -> usize {
218        self.shape.iter().map(|&dim| dim as usize).product()
219    }
220
221    /// Logical size of the element data in bytes (`numel * dtype_size`).
222    pub fn nbytes(&self) -> usize {
223        self.numel() * dtype_size(self.dtype)
224    }
225
226    /// Whether elements are laid out C-contiguously.
227    pub fn is_contiguous(&self) -> bool {
228        let Some(strides) = &self.strides else {
229            return true;
230        };
231        let mut expected = 1i64;
232        for (&dim, &stride) in self.shape.iter().zip(strides).rev() {
233            if dim == 0 {
234                return true;
235            }
236            if dim != 1 {
237                if stride != expected {
238                    return false;
239                }
240                expected *= dim;
241            }
242        }
243        true
244    }
245
246    /// A tensor with the same elements and a new shape.
247    ///
248    /// At most one dimension may be `-1`, which is inferred from the
249    /// element count. Returns a zero-copy view sharing this tensor's
250    /// storage when the layout is contiguous, and a contiguous copy
251    /// otherwise.
252    pub fn reshape(&self, shape: &[i64]) -> Result<Self, TensorError> {
253        let shape = self.resolve_reshape_dims(shape)?;
254        let to = checked_numel(&shape)?;
255        let from = self.numel();
256        if from != to {
257            return Err(TensorError::NumelMismatch { from, to });
258        }
259        if self.is_contiguous() {
260            Self::from_storage(
261                self.storage.clone(),
262                self.dtype,
263                shape,
264                None,
265                self.byte_offset,
266            )
267        } else {
268            let storage = Storage::aligned_with(self.nbytes(), |buf| self.gather_into(buf));
269            Self::from_storage(storage, self.dtype, shape, None, 0)
270        }
271    }
272
273    /// Replace a single `-1` dimension with the size inferred from this
274    /// tensor's element count.
275    fn resolve_reshape_dims(&self, shape: &[i64]) -> Result<Vec<i64>, TensorError> {
276        let wildcards = shape.iter().filter(|&&dim| dim == -1).count();
277        if wildcards > 1 {
278            return Err(TensorError::AmbiguousReshape);
279        }
280        if wildcards == 0 {
281            return Ok(shape.to_vec());
282        }
283        let mut known = 1usize;
284        for &dim in shape {
285            if dim < -1 {
286                return Err(TensorError::NegativeDim(dim));
287            }
288            if dim >= 0 {
289                known = known
290                    .checked_mul(dim as usize)
291                    .ok_or(TensorError::Overflow)?;
292            }
293        }
294        let from = self.numel();
295        if known == 0 || !from.is_multiple_of(known) {
296            return Err(TensorError::NumelMismatch { from, to: known });
297        }
298        let inferred = (from / known) as i64;
299        Ok(shape
300            .iter()
301            .map(|&dim| if dim == -1 { inferred } else { dim })
302            .collect())
303    }
304
305    /// The element bytes in C order.
306    ///
307    /// Borrows from storage when the layout is contiguous; gathers into a
308    /// fresh buffer otherwise.
309    pub fn to_contiguous_bytes(&self) -> Cow<'_, [u8]> {
310        if self.is_contiguous() {
311            let start = self.byte_offset;
312            return Cow::Borrowed(&self.storage.as_slice()[start..start + self.nbytes()]);
313        }
314        let mut out = Vec::with_capacity(self.nbytes());
315        self.gather_into(&mut out);
316        Cow::Owned(out)
317    }
318
319    /// Stack tensors of identical dtype and shape along a new leading axis.
320    ///
321    /// The result is a fresh 64-byte-aligned contiguous tensor of shape
322    /// `[tensors.len(), ..shape]`.
323    pub fn stack(tensors: &[Tensor]) -> Result<Tensor, TensorError> {
324        let Some(first) = tensors.first() else {
325            return Err(TensorError::EmptyStack);
326        };
327        for (index, tensor) in tensors.iter().enumerate() {
328            if tensor.dtype != first.dtype || tensor.shape != first.shape {
329                return Err(TensorError::StackMismatch { index });
330            }
331        }
332        let total = first
333            .nbytes()
334            .checked_mul(tensors.len())
335            .ok_or(TensorError::Overflow)?;
336        let mut shape = Vec::with_capacity(first.shape.len() + 1);
337        shape.push(tensors.len() as i64);
338        shape.extend_from_slice(&first.shape);
339        let storage = Storage::aligned_with(total, |buf| {
340            for tensor in tensors {
341                match tensor.to_contiguous_bytes() {
342                    Cow::Borrowed(bytes) => buf.extend_from_slice(bytes),
343                    Cow::Owned(bytes) => buf.extend_from_slice(&bytes),
344                }
345            }
346        });
347        Self::from_storage(storage, first.dtype, shape, None, 0)
348    }
349
350    /// Split along axis 0 into zero-copy views sharing this storage.
351    pub fn unstack(&self) -> Result<Vec<Tensor>, TensorError> {
352        if self.shape.is_empty() {
353            return Err(TensorError::UnstackScalar);
354        }
355        let count = self.shape[0] as usize;
356        let inner_shape = &self.shape[1..];
357        let strides = self.effective_strides();
358        let outer_stride_bytes = strides[0] as usize * dtype_size(self.dtype);
359        let inner_strides = self.strides.as_ref().map(|_| strides[1..].to_vec());
360        (0..count)
361            .map(|index| {
362                Self::from_storage(
363                    self.storage.clone(),
364                    self.dtype,
365                    inner_shape.to_vec(),
366                    inner_strides.clone(),
367                    self.byte_offset + index * outer_stride_bytes,
368                )
369            })
370            .collect()
371    }
372
373    /// Copy the element bytes in C order into `out`.
374    fn gather_into(&self, out: &mut Vec<u8>) {
375        let itemsize = dtype_size(self.dtype);
376        let strides = self.effective_strides();
377        let data = &self.storage.as_slice()[self.byte_offset..];
378        let mut index = vec![0usize; self.shape.len()];
379        for _ in 0..self.numel() {
380            let element: usize = index
381                .iter()
382                .zip(strides.iter())
383                .map(|(&i, &stride)| i * stride as usize)
384                .sum();
385            let start = element * itemsize;
386            out.extend_from_slice(&data[start..start + itemsize]);
387            for axis in (0..index.len()).rev() {
388                index[axis] += 1;
389                if (index[axis] as i64) < self.shape[axis] {
390                    break;
391                }
392                index[axis] = 0;
393            }
394        }
395    }
396}
397
398impl PartialEq for Tensor {
399    fn eq(&self, other: &Self) -> bool {
400        self.dtype == other.dtype
401            && self.shape == other.shape
402            && self.to_contiguous_bytes() == other.to_contiguous_bytes()
403    }
404}
405
406fn checked_numel(shape: &[i64]) -> Result<usize, TensorError> {
407    let mut numel = 1usize;
408    for &dim in shape {
409        if dim < 0 {
410            return Err(TensorError::NegativeDim(dim));
411        }
412        numel = numel
413            .checked_mul(dim as usize)
414            .ok_or(TensorError::Overflow)?;
415    }
416    Ok(numel)
417}
418
419fn checked_nbytes(shape: &[i64], dtype: DType) -> Result<usize, TensorError> {
420    if dtype == DType::Unspecified {
421        return Err(TensorError::UnspecifiedDtype);
422    }
423    checked_numel(shape)?
424        .checked_mul(dtype_size(dtype))
425        .ok_or(TensorError::Overflow)
426}
427
428/// Bytes a view must be able to address past its byte offset: one item plus
429/// the span reached by the largest in-bounds index on every axis.
430fn required_bytes(
431    shape: &[i64],
432    strides: Option<&[i64]>,
433    dtype: DType,
434) -> Result<usize, TensorError> {
435    let numel = checked_numel(shape)?;
436    if numel == 0 {
437        return Ok(0);
438    }
439    let itemsize = dtype_size(dtype);
440    let Some(strides) = strides else {
441        return numel.checked_mul(itemsize).ok_or(TensorError::Overflow);
442    };
443    let mut last_element = 0usize;
444    for (&dim, &stride) in shape.iter().zip(strides) {
445        let span = (dim as usize - 1)
446            .checked_mul(stride as usize)
447            .ok_or(TensorError::Overflow)?;
448        last_element = last_element
449            .checked_add(span)
450            .ok_or(TensorError::Overflow)?;
451    }
452    last_element
453        .checked_add(1)
454        .ok_or(TensorError::Overflow)?
455        .checked_mul(itemsize)
456        .ok_or(TensorError::Overflow)
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Little-endian byte image of a run of `f32` values.
    fn f32_bytes(values: &[f32]) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(values.len() * 4);
        for value in values {
            bytes.extend_from_slice(&value.to_le_bytes());
        }
        bytes
    }

    /// Fresh storage holding `values` as little-endian `f32`s.
    fn f32_storage(values: &[f32]) -> Storage {
        Storage::from_slice(&f32_bytes(values))
    }

    #[test]
    fn test_from_vec_adopts_and_validates_length() {
        let adopted = Tensor::from_vec(f32_bytes(&[1.0, 2.0, 3.0]), vec![3], DType::Float32)
            .expect("valid tensor");
        assert_eq!(adopted.shape(), &[3]);
        assert_eq!(adopted.numel(), 3);
        assert_eq!(adopted.nbytes(), 12);
        assert!(adopted.is_contiguous());
        assert_eq!(adopted.strides(), None);
        assert_eq!(adopted.effective_strides().as_ref(), &[1]);
        assert_eq!(adopted.device(), Device::Cpu);

        // One byte short of the three f32 values the shape asks for.
        let too_short = Tensor::from_vec(vec![0u8; 11], vec![3], DType::Float32);
        assert_eq!(
            too_short,
            Err(TensorError::ByteLengthMismatch {
                expected: 12,
                actual: 11
            })
        );
    }

    #[test]
    fn test_constructor_rejects_invalid_inputs() {
        // Float32 view at offset 0 over zeroed storage of the given size.
        let f32_view = |storage_len: usize, shape: Vec<i64>, strides: Option<Vec<i64>>| {
            Tensor::from_storage(
                Storage::zeroed(storage_len),
                DType::Float32,
                shape,
                strides,
                0,
            )
        };

        let unspecified = Tensor::from_vec(vec![], vec![2], DType::Unspecified);
        assert_eq!(unspecified, Err(TensorError::UnspecifiedDtype));

        let negative_dim = Tensor::from_vec(vec![], vec![-1], DType::Float32);
        assert_eq!(negative_dim, Err(TensorError::NegativeDim(-1)));

        assert_eq!(
            f32_view(8, vec![2], Some(vec![1, 1])),
            Err(TensorError::StrideRankMismatch {
                strides: 2,
                shape: 1
            })
        );
        assert_eq!(
            f32_view(8, vec![2], Some(vec![-1])),
            Err(TensorError::NegativeStride(-1))
        );
        assert_eq!(
            f32_view(8, vec![3], None),
            Err(TensorError::OutOfBounds {
                required: 12,
                byte_offset: 0,
                available: 8
            })
        );
        // Strided view reaching past the storage end.
        assert_eq!(
            f32_view(12, vec![2], Some(vec![3])),
            Err(TensorError::OutOfBounds {
                required: 16,
                byte_offset: 0,
                available: 12
            })
        );

        let huge = Tensor::from_vec(vec![], vec![i64::MAX, i64::MAX], DType::Float32);
        assert_eq!(huge, Err(TensorError::Overflow));
    }

    #[test]
    fn test_zeros_is_aligned_and_zero_filled() {
        let zeroed = Tensor::zeros(&[4, 4], DType::Int32).expect("valid tensor");
        assert_eq!(zeroed.nbytes(), 64);
        let bytes = zeroed.to_contiguous_bytes();
        assert!(bytes.iter().all(|&b| b == 0));
        let address = zeroed.storage().as_slice().as_ptr() as usize;
        assert_eq!(address % 64, 0);
    }

    #[test]
    fn test_scalar_tensor() {
        let payload = 1.0f64.to_le_bytes();
        let scalar = Tensor::from_slice(&payload, &[], DType::Float64).expect("valid tensor");
        assert_eq!(scalar.shape(), &[] as &[i64]);
        assert_eq!(scalar.numel(), 1);
        assert_eq!(scalar.to_contiguous_bytes().as_ref(), 1.0f64.to_le_bytes());
    }

    #[test]
    fn test_reshape_contiguous_is_view() {
        let data = f32_bytes(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let source = Tensor::from_slice(&data, &[2, 3], DType::Float32).expect("valid tensor");

        let transposed_shape = source.reshape(&[3, 2]).expect("valid reshape");
        assert!(transposed_shape.storage().ptr_eq(source.storage()));
        assert_eq!(transposed_shape.shape(), &[3, 2]);
        assert_eq!(transposed_shape.byte_offset(), source.byte_offset());

        let wrong_count = source.reshape(&[4, 2]);
        assert_eq!(
            wrong_count,
            Err(TensorError::NumelMismatch { from: 6, to: 8 })
        );
    }

    #[test]
    fn test_reshape_infers_one_dimension() {
        let source = Tensor::zeros(&[2, 3, 4], DType::Uint8).expect("valid tensor");

        let middle = source.reshape(&[2, -1, 3]).expect("valid reshape");
        assert_eq!(middle.shape(), &[2, 4, 3]);
        assert!(middle.storage().ptr_eq(source.storage()));

        let flattened = source.reshape(&[-1]).expect("valid reshape");
        assert_eq!(flattened.shape(), &[24]);

        let two_wildcards = source.reshape(&[-1, -1]);
        assert_eq!(two_wildcards, Err(TensorError::AmbiguousReshape));

        let indivisible = source.reshape(&[-1, 5]);
        assert_eq!(
            indivisible,
            Err(TensorError::NumelMismatch { from: 24, to: 5 })
        );

        let below_wildcard = source.reshape(&[-2, 4]);
        assert_eq!(below_wildcard, Err(TensorError::NegativeDim(-2)));
    }

    #[test]
    fn test_reshape_inference_on_empty_tensors() {
        let no_rows = Tensor::zeros(&[0, 3], DType::Float32).expect("valid tensor");

        // 0 elements / 3 known => inferred 0.
        let resolved = no_rows.reshape(&[-1, 3]).expect("valid reshape");
        assert_eq!(resolved.shape(), &[0, 3]);

        // A zero-sized known dimension leaves -1 ambiguous.
        let undetermined = no_rows.reshape(&[0, -1]);
        assert_eq!(
            undetermined,
            Err(TensorError::NumelMismatch { from: 0, to: 0 })
        );
    }

    #[test]
    fn test_reshape_strided_copies() {
        // Column-major 2x2 layout: storage [1, 3, 2, 4] viewed with strides [1, 2]
        // reads as [[1, 2], [3, 4]].
        let column_major = Tensor::from_storage(
            f32_storage(&[1.0, 3.0, 2.0, 4.0]),
            DType::Float32,
            vec![2, 2],
            Some(vec![1, 2]),
            0,
        )
        .expect("valid tensor");
        assert!(!column_major.is_contiguous());

        let flattened = column_major.reshape(&[4]).expect("valid reshape");
        assert!(!flattened.storage().ptr_eq(column_major.storage()));
        assert!(flattened.is_contiguous());
        let expected = f32_bytes(&[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(flattened.to_contiguous_bytes().as_ref(), expected.as_slice());
    }

    #[test]
    fn test_to_contiguous_bytes_borrows_when_contiguous() {
        let dense =
            Tensor::from_slice(&f32_bytes(&[1.0, 2.0]), &[2], DType::Float32).expect("valid");
        assert!(matches!(dense.to_contiguous_bytes(), Cow::Borrowed(_)));

        let every_other = Tensor::from_storage(
            f32_storage(&[1.0, 2.0, 3.0, 4.0]),
            DType::Float32,
            vec![2],
            Some(vec![2]),
            0,
        )
        .expect("valid tensor");
        assert!(!every_other.is_contiguous());
        let collected = every_other.to_contiguous_bytes();
        assert!(matches!(collected, Cow::Owned(_)));
        assert_eq!(collected.as_ref(), f32_bytes(&[1.0, 3.0]).as_slice());
    }

    #[test]
    fn test_strided_gather_multi_dimensional() {
        // 3x4 storage; view the 3x2 sub-tensor of even columns.
        let grid: Vec<f32> = (0..12).map(|v| v as f32).collect();
        let even_columns = Tensor::from_storage(
            f32_storage(&grid),
            DType::Float32,
            vec![3, 2],
            Some(vec![4, 2]),
            0,
        )
        .expect("valid tensor");
        let expected = f32_bytes(&[0.0, 2.0, 4.0, 6.0, 8.0, 10.0]);
        assert_eq!(
            even_columns.to_contiguous_bytes().as_ref(),
            expected.as_slice()
        );
    }

    #[test]
    fn test_stack_and_unstack_roundtrip() {
        let mut tensors: Vec<Tensor> = Vec::new();
        for i in 0..3 {
            let row = f32_bytes(&[i as f32, i as f32 + 0.5]);
            tensors.push(Tensor::from_slice(&row, &[2], DType::Float32).expect("valid tensor"));
        }

        let stacked = Tensor::stack(&tensors).expect("valid stack");
        assert_eq!(stacked.shape(), &[3, 2]);
        assert!(stacked.is_contiguous());
        assert_eq!(stacked.storage().as_slice().as_ptr() as usize % 64, 0);

        let views = stacked.unstack().expect("valid unstack");
        assert_eq!(views.len(), 3);
        for (index, view) in views.iter().enumerate() {
            let original = &tensors[index];
            assert!(view.storage().ptr_eq(stacked.storage()));
            assert_eq!(view.byte_offset(), index * 8);
            assert_eq!(view, original);
        }
    }

    #[test]
    fn test_stack_rejects_empty_and_mismatched() {
        assert_eq!(Tensor::stack(&[]), Err(TensorError::EmptyStack));

        let base = Tensor::zeros(&[2], DType::Float32).expect("valid tensor");
        let other_shape = Tensor::zeros(&[3], DType::Float32).expect("valid tensor");
        let other_dtype = Tensor::zeros(&[2], DType::Int32).expect("valid tensor");

        let shape_clash = Tensor::stack(&[base.clone(), other_shape]);
        assert_eq!(shape_clash, Err(TensorError::StackMismatch { index: 1 }));

        let dtype_clash = Tensor::stack(&[base, other_dtype]);
        assert_eq!(dtype_clash, Err(TensorError::StackMismatch { index: 1 }));
    }

    #[test]
    fn test_unstack_scalar_fails() {
        let rank_zero = Tensor::zeros(&[], DType::Float32).expect("valid tensor");
        assert_eq!(rank_zero.unstack(), Err(TensorError::UnstackScalar));
    }

    #[test]
    fn test_partial_eq_is_logical() {
        // Strided view of [1, 3] vs a contiguous [1, 3]: equal content,
        // different storage layout.
        let every_other = Tensor::from_storage(
            f32_storage(&[1.0, 2.0, 3.0, 4.0]),
            DType::Float32,
            vec![2],
            Some(vec![2]),
            0,
        )
        .expect("valid tensor");
        let dense =
            Tensor::from_slice(&f32_bytes(&[1.0, 3.0]), &[2], DType::Float32).expect("valid");
        assert_eq!(every_other, dense);

        // Identical bytes and shape, different dtype.
        let unsigned = Tensor::from_slice(&[0u8; 2], &[2], DType::Uint8).expect("valid");
        let signed = Tensor::from_slice(&[0u8; 2], &[2], DType::Int8).expect("valid");
        assert_ne!(unsigned, signed);

        // Identical bytes and dtype, different shape.
        let line = Tensor::zeros(&[4], DType::Float32).expect("valid");
        let grid = Tensor::zeros(&[2, 2], DType::Float32).expect("valid");
        assert_ne!(line, grid);
    }

    #[test]
    fn test_view_with_byte_offset() {
        let second_half = Tensor::from_storage(
            f32_storage(&[1.0, 2.0, 3.0, 4.0]),
            DType::Float32,
            vec![2],
            None,
            8,
        )
        .expect("valid tensor");
        let expected = f32_bytes(&[3.0, 4.0]);
        assert_eq!(
            second_half.to_contiguous_bytes().as_ref(),
            expected.as_slice()
        );
        assert_eq!(second_half.byte_offset(), 8);
    }

    #[test]
    fn test_empty_view_offset_must_stay_inside_storage() {
        // Zero-element views need no bytes, but an offset past the storage
        // would make to_contiguous_bytes() slice out of bounds.
        let backing = Storage::zeroed(8);
        let past_end = Tensor::from_storage(backing.clone(), DType::Float32, vec![0], None, 9);
        assert_eq!(
            past_end,
            Err(TensorError::OutOfBounds {
                required: 0,
                byte_offset: 9,
                available: 8
            })
        );
        // An offset exactly at the end is fine for an empty view.
        let at_end = Tensor::from_storage(backing, DType::Float32, vec![0], None, 8)
            .expect("valid empty view");
        assert!(at_end.to_contiguous_bytes().is_empty());
    }

    #[test]
    fn test_empty_tensor() {
        let no_rows = Tensor::zeros(&[0, 3], DType::Float32).expect("valid tensor");
        assert_eq!(no_rows.numel(), 0);
        assert_eq!(no_rows.nbytes(), 0);
        assert!(no_rows.is_contiguous());
        assert!(no_rows.to_contiguous_bytes().is_empty());
        let slices = no_rows.unstack().expect("valid unstack");
        assert!(slices.is_empty());
    }

    #[test]
    fn test_contiguous_strides_table() {
        assert_eq!(contiguous_strides(&[]), Vec::<i64>::new());
        assert_eq!(contiguous_strides(&[5]), vec![1]);
        assert_eq!(contiguous_strides(&[2, 3]), vec![3, 1]);
        assert_eq!(contiguous_strides(&[2, 3, 4]), vec![12, 4, 1]);
        assert_eq!(contiguous_strides(&[0, 3]), vec![3, 1]);
    }

    #[test]
    fn test_explicit_contiguous_strides_detected() {
        let row_major = Tensor::from_storage(
            f32_storage(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
            DType::Float32,
            vec![2, 3],
            Some(vec![3, 1]),
            0,
        )
        .expect("valid tensor");
        assert!(row_major.is_contiguous());
        assert!(matches!(row_major.to_contiguous_bytes(), Cow::Borrowed(_)));
    }
}