1mod dlpack;
2#[cfg(test)]
3mod proptests;
4mod storage;
5
6pub use dlpack::{DLPackType, dlpack_type, dtype_from_dlpack};
7pub use storage::Storage;
8
9use std::borrow::Cow;
10
11use thiserror::Error;
12
13use crate::dtype::{DType, dtype_size};
14
/// Device that holds a tensor's memory.
///
/// The enum is `#[repr(i32)]` with explicit discriminants, so a variant can be
/// cast directly to its integer code. It is `#[non_exhaustive]`: matches
/// written outside this crate need a wildcard arm.
#[non_exhaustive]
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Device {
    /// Host memory.
    ///
    /// NOTE(review): the code `1` presumably mirrors DLPack's `kDLCPU` — confirm.
    Cpu = 1,
}
23
24impl From<Device> for i32 {
25 fn from(value: Device) -> Self {
26 value as i32
27 }
28}
29
30#[derive(Error, Debug, Clone, PartialEq, Eq)]
32pub enum TensorError {
33 #[error("tensor dtype must be specified")]
34 UnspecifiedDtype,
35 #[error("negative dimension {0} in shape")]
36 NegativeDim(i64),
37 #[error("negative stride {0}")]
38 NegativeStride(i64),
39 #[error("strides rank {strides} does not match shape rank {shape}")]
40 StrideRankMismatch { strides: usize, shape: usize },
41 #[error("data is {actual} bytes but shape and dtype require exactly {expected}")]
42 ByteLengthMismatch { expected: usize, actual: usize },
43 #[error(
44 "view requires {required} bytes at byte offset {byte_offset} but storage holds {available}"
45 )]
46 OutOfBounds {
47 required: usize,
48 byte_offset: usize,
49 available: usize,
50 },
51 #[error("cannot reshape {from} elements into {to}")]
52 NumelMismatch { from: usize, to: usize },
53 #[error("reshape shape may contain at most one -1")]
54 AmbiguousReshape,
55 #[error("stack requires at least one tensor")]
56 EmptyStack,
57 #[error("stack requires uniform dtype and shape; tensor {index} differs")]
58 StackMismatch { index: usize },
59 #[error("cannot unstack a 0-dimensional tensor")]
60 UnstackScalar,
61 #[error("tensor size overflows usize")]
62 Overflow,
63}
64
/// Returns the row-major (C-order) strides, in elements, for `shape`.
///
/// The innermost axis gets stride 1 and every other axis gets the product of
/// all extents to its right. Extents are used exactly as given, so a zero
/// extent makes every stride to its left zero.
///
/// Products that would exceed the `i64` range saturate instead of
/// overflowing. That matters for empty shapes such as `[0, i64::MAX, i64::MAX]`,
/// which hold no elements but whose trailing extents do not multiply within
/// `i64`; plain multiplication would panic in debug builds and wrap in
/// release builds.
pub fn contiguous_strides(shape: &[i64]) -> Vec<i64> {
    let mut strides = vec![0i64; shape.len()];
    let mut running = 1i64;
    // Walk from the innermost axis outwards, accumulating the extent product.
    for (slot, &dim) in strides.iter_mut().zip(shape).rev() {
        *slot = running;
        running = running.saturating_mul(dim);
    }
    strides
}
75
76#[derive(Debug, Clone)]
87pub struct Tensor {
88 storage: Storage,
89 dtype: DType,
90 shape: Vec<i64>,
91 strides: Option<Vec<i64>>,
92 byte_offset: usize,
93}
94
95impl Tensor {
96 pub fn from_vec(data: Vec<u8>, shape: Vec<i64>, dtype: DType) -> Result<Self, TensorError> {
101 let expected = checked_nbytes(&shape, dtype)?;
102 if data.len() != expected {
103 return Err(TensorError::ByteLengthMismatch {
104 expected,
105 actual: data.len(),
106 });
107 }
108 Self::from_storage(Storage::from_vec(data), dtype, shape, None, 0)
109 }
110
111 pub fn from_slice(data: &[u8], shape: &[i64], dtype: DType) -> Result<Self, TensorError> {
115 let expected = checked_nbytes(shape, dtype)?;
116 if data.len() != expected {
117 return Err(TensorError::ByteLengthMismatch {
118 expected,
119 actual: data.len(),
120 });
121 }
122 Self::from_storage(Storage::from_slice(data), dtype, shape.to_vec(), None, 0)
123 }
124
125 pub fn zeros(shape: &[i64], dtype: DType) -> Result<Self, TensorError> {
127 let nbytes = checked_nbytes(shape, dtype)?;
128 Self::from_storage(Storage::zeroed(nbytes), dtype, shape.to_vec(), None, 0)
129 }
130
131 pub fn from_storage(
137 storage: Storage,
138 dtype: DType,
139 shape: Vec<i64>,
140 strides: Option<Vec<i64>>,
141 byte_offset: usize,
142 ) -> Result<Self, TensorError> {
143 if dtype == DType::Unspecified {
144 return Err(TensorError::UnspecifiedDtype);
145 }
146 if let Some(strides) = &strides {
147 if strides.len() != shape.len() {
148 return Err(TensorError::StrideRankMismatch {
149 strides: strides.len(),
150 shape: shape.len(),
151 });
152 }
153 for &stride in strides {
154 if stride < 0 {
155 return Err(TensorError::NegativeStride(stride));
156 }
157 }
158 }
159 let required = required_bytes(&shape, strides.as_deref(), dtype)?;
160 if byte_offset > storage.len() || required > storage.len() - byte_offset {
163 return Err(TensorError::OutOfBounds {
164 required,
165 byte_offset,
166 available: storage.len(),
167 });
168 }
169 Ok(Self {
170 storage,
171 dtype,
172 shape,
173 strides,
174 byte_offset,
175 })
176 }
177
178 pub fn dtype(&self) -> DType {
180 self.dtype
181 }
182
183 pub fn shape(&self) -> &[i64] {
185 &self.shape
186 }
187
188 pub fn strides(&self) -> Option<&[i64]> {
190 self.strides.as_deref()
191 }
192
193 pub fn effective_strides(&self) -> Cow<'_, [i64]> {
195 match &self.strides {
196 Some(strides) => Cow::Borrowed(strides),
197 None => Cow::Owned(contiguous_strides(&self.shape)),
198 }
199 }
200
201 pub fn byte_offset(&self) -> usize {
203 self.byte_offset
204 }
205
206 pub fn device(&self) -> Device {
208 Device::Cpu
209 }
210
211 pub fn storage(&self) -> &Storage {
213 &self.storage
214 }
215
216 pub fn numel(&self) -> usize {
218 self.shape.iter().map(|&dim| dim as usize).product()
219 }
220
221 pub fn nbytes(&self) -> usize {
223 self.numel() * dtype_size(self.dtype)
224 }
225
226 pub fn is_contiguous(&self) -> bool {
228 let Some(strides) = &self.strides else {
229 return true;
230 };
231 let mut expected = 1i64;
232 for (&dim, &stride) in self.shape.iter().zip(strides).rev() {
233 if dim == 0 {
234 return true;
235 }
236 if dim != 1 {
237 if stride != expected {
238 return false;
239 }
240 expected *= dim;
241 }
242 }
243 true
244 }
245
246 pub fn reshape(&self, shape: &[i64]) -> Result<Self, TensorError> {
253 let shape = self.resolve_reshape_dims(shape)?;
254 let to = checked_numel(&shape)?;
255 let from = self.numel();
256 if from != to {
257 return Err(TensorError::NumelMismatch { from, to });
258 }
259 if self.is_contiguous() {
260 Self::from_storage(
261 self.storage.clone(),
262 self.dtype,
263 shape,
264 None,
265 self.byte_offset,
266 )
267 } else {
268 let storage = Storage::aligned_with(self.nbytes(), |buf| self.gather_into(buf));
269 Self::from_storage(storage, self.dtype, shape, None, 0)
270 }
271 }
272
273 fn resolve_reshape_dims(&self, shape: &[i64]) -> Result<Vec<i64>, TensorError> {
276 let wildcards = shape.iter().filter(|&&dim| dim == -1).count();
277 if wildcards > 1 {
278 return Err(TensorError::AmbiguousReshape);
279 }
280 if wildcards == 0 {
281 return Ok(shape.to_vec());
282 }
283 let mut known = 1usize;
284 for &dim in shape {
285 if dim < -1 {
286 return Err(TensorError::NegativeDim(dim));
287 }
288 if dim >= 0 {
289 known = known
290 .checked_mul(dim as usize)
291 .ok_or(TensorError::Overflow)?;
292 }
293 }
294 let from = self.numel();
295 if known == 0 || !from.is_multiple_of(known) {
296 return Err(TensorError::NumelMismatch { from, to: known });
297 }
298 let inferred = (from / known) as i64;
299 Ok(shape
300 .iter()
301 .map(|&dim| if dim == -1 { inferred } else { dim })
302 .collect())
303 }
304
305 pub fn to_contiguous_bytes(&self) -> Cow<'_, [u8]> {
310 if self.is_contiguous() {
311 let start = self.byte_offset;
312 return Cow::Borrowed(&self.storage.as_slice()[start..start + self.nbytes()]);
313 }
314 let mut out = Vec::with_capacity(self.nbytes());
315 self.gather_into(&mut out);
316 Cow::Owned(out)
317 }
318
319 pub fn stack(tensors: &[Tensor]) -> Result<Tensor, TensorError> {
324 let Some(first) = tensors.first() else {
325 return Err(TensorError::EmptyStack);
326 };
327 for (index, tensor) in tensors.iter().enumerate() {
328 if tensor.dtype != first.dtype || tensor.shape != first.shape {
329 return Err(TensorError::StackMismatch { index });
330 }
331 }
332 let total = first
333 .nbytes()
334 .checked_mul(tensors.len())
335 .ok_or(TensorError::Overflow)?;
336 let mut shape = Vec::with_capacity(first.shape.len() + 1);
337 shape.push(tensors.len() as i64);
338 shape.extend_from_slice(&first.shape);
339 let storage = Storage::aligned_with(total, |buf| {
340 for tensor in tensors {
341 match tensor.to_contiguous_bytes() {
342 Cow::Borrowed(bytes) => buf.extend_from_slice(bytes),
343 Cow::Owned(bytes) => buf.extend_from_slice(&bytes),
344 }
345 }
346 });
347 Self::from_storage(storage, first.dtype, shape, None, 0)
348 }
349
350 pub fn unstack(&self) -> Result<Vec<Tensor>, TensorError> {
352 if self.shape.is_empty() {
353 return Err(TensorError::UnstackScalar);
354 }
355 let count = self.shape[0] as usize;
356 let inner_shape = &self.shape[1..];
357 let strides = self.effective_strides();
358 let outer_stride_bytes = strides[0] as usize * dtype_size(self.dtype);
359 let inner_strides = self.strides.as_ref().map(|_| strides[1..].to_vec());
360 (0..count)
361 .map(|index| {
362 Self::from_storage(
363 self.storage.clone(),
364 self.dtype,
365 inner_shape.to_vec(),
366 inner_strides.clone(),
367 self.byte_offset + index * outer_stride_bytes,
368 )
369 })
370 .collect()
371 }
372
373 fn gather_into(&self, out: &mut Vec<u8>) {
375 let itemsize = dtype_size(self.dtype);
376 let strides = self.effective_strides();
377 let data = &self.storage.as_slice()[self.byte_offset..];
378 let mut index = vec![0usize; self.shape.len()];
379 for _ in 0..self.numel() {
380 let element: usize = index
381 .iter()
382 .zip(strides.iter())
383 .map(|(&i, &stride)| i * stride as usize)
384 .sum();
385 let start = element * itemsize;
386 out.extend_from_slice(&data[start..start + itemsize]);
387 for axis in (0..index.len()).rev() {
388 index[axis] += 1;
389 if (index[axis] as i64) < self.shape[axis] {
390 break;
391 }
392 index[axis] = 0;
393 }
394 }
395 }
396}
397
398impl PartialEq for Tensor {
399 fn eq(&self, other: &Self) -> bool {
400 self.dtype == other.dtype
401 && self.shape == other.shape
402 && self.to_contiguous_bytes() == other.to_contiguous_bytes()
403 }
404}
405
406fn checked_numel(shape: &[i64]) -> Result<usize, TensorError> {
407 let mut numel = 1usize;
408 for &dim in shape {
409 if dim < 0 {
410 return Err(TensorError::NegativeDim(dim));
411 }
412 numel = numel
413 .checked_mul(dim as usize)
414 .ok_or(TensorError::Overflow)?;
415 }
416 Ok(numel)
417}
418
419fn checked_nbytes(shape: &[i64], dtype: DType) -> Result<usize, TensorError> {
420 if dtype == DType::Unspecified {
421 return Err(TensorError::UnspecifiedDtype);
422 }
423 checked_numel(shape)?
424 .checked_mul(dtype_size(dtype))
425 .ok_or(TensorError::Overflow)
426}
427
428fn required_bytes(
431 shape: &[i64],
432 strides: Option<&[i64]>,
433 dtype: DType,
434) -> Result<usize, TensorError> {
435 let numel = checked_numel(shape)?;
436 if numel == 0 {
437 return Ok(0);
438 }
439 let itemsize = dtype_size(dtype);
440 let Some(strides) = strides else {
441 return numel.checked_mul(itemsize).ok_or(TensorError::Overflow);
442 };
443 let mut last_element = 0usize;
444 for (&dim, &stride) in shape.iter().zip(strides) {
445 let span = (dim as usize - 1)
446 .checked_mul(stride as usize)
447 .ok_or(TensorError::Overflow)?;
448 last_element = last_element
449 .checked_add(span)
450 .ok_or(TensorError::Overflow)?;
451 }
452 last_element
453 .checked_add(1)
454 .ok_or(TensorError::Overflow)?
455 .checked_mul(itemsize)
456 .ok_or(TensorError::Overflow)
457}
458
#[cfg(test)]
mod tests {
    //! Unit tests for tensor construction, views, reshape and stack/unstack.

    use super::*;

    /// Encodes `values` as consecutive little-endian `f32` bytes.
    fn f32_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    /// `from_vec` accepts an exactly sized buffer and rejects any other length.
    #[test]
    fn test_from_vec_adopts_and_validates_length() {
        let tensor = Tensor::from_vec(f32_bytes(&[1.0, 2.0, 3.0]), vec![3], DType::Float32)
            .expect("valid tensor");
        assert_eq!(tensor.shape(), &[3]);
        assert_eq!(tensor.numel(), 3);
        assert_eq!(tensor.nbytes(), 12);
        assert!(tensor.is_contiguous());
        assert_eq!(tensor.strides(), None);
        assert_eq!(tensor.effective_strides().as_ref(), &[1]);
        assert_eq!(tensor.device(), Device::Cpu);

        // One byte short of the 12 bytes three f32 values need.
        assert_eq!(
            Tensor::from_vec(vec![0u8; 11], vec![3], DType::Float32),
            Err(TensorError::ByteLengthMismatch {
                expected: 12,
                actual: 11
            })
        );
    }

    /// Each validation step of the constructors reports its own error.
    #[test]
    fn test_constructor_rejects_invalid_inputs() {
        assert_eq!(
            Tensor::from_vec(vec![], vec![2], DType::Unspecified),
            Err(TensorError::UnspecifiedDtype)
        );
        assert_eq!(
            Tensor::from_vec(vec![], vec![-1], DType::Float32),
            Err(TensorError::NegativeDim(-1))
        );
        assert_eq!(
            Tensor::from_storage(
                Storage::zeroed(8),
                DType::Float32,
                vec![2],
                Some(vec![1, 1]),
                0
            ),
            Err(TensorError::StrideRankMismatch {
                strides: 2,
                shape: 1
            })
        );
        assert_eq!(
            Tensor::from_storage(
                Storage::zeroed(8),
                DType::Float32,
                vec![2],
                Some(vec![-1]),
                0
            ),
            Err(TensorError::NegativeStride(-1))
        );
        assert_eq!(
            Tensor::from_storage(Storage::zeroed(8), DType::Float32, vec![3], None, 0),
            Err(TensorError::OutOfBounds {
                required: 12,
                byte_offset: 0,
                available: 8
            })
        );
        // Two elements spaced three apart reach element 3, i.e. 16 bytes.
        assert_eq!(
            Tensor::from_storage(
                Storage::zeroed(12),
                DType::Float32,
                vec![2],
                Some(vec![3]),
                0
            ),
            Err(TensorError::OutOfBounds {
                required: 16,
                byte_offset: 0,
                available: 12
            })
        );
        assert_eq!(
            Tensor::from_vec(vec![], vec![i64::MAX, i64::MAX], DType::Float32),
            Err(TensorError::Overflow)
        );
    }

    /// `zeros` yields zero bytes in storage whose start is 64-byte aligned.
    #[test]
    fn test_zeros_is_aligned_and_zero_filled() {
        let tensor = Tensor::zeros(&[4, 4], DType::Int32).expect("valid tensor");
        assert_eq!(tensor.nbytes(), 64);
        assert!(tensor.to_contiguous_bytes().iter().all(|&b| b == 0));
        assert_eq!(tensor.storage().as_slice().as_ptr() as usize % 64, 0);
    }

    /// A 0-dimensional tensor holds exactly one element.
    #[test]
    fn test_scalar_tensor() {
        let tensor =
            Tensor::from_slice(&1.0f64.to_le_bytes(), &[], DType::Float64).expect("valid tensor");
        assert_eq!(tensor.shape(), &[] as &[i64]);
        assert_eq!(tensor.numel(), 1);
        assert_eq!(tensor.to_contiguous_bytes().as_ref(), 1.0f64.to_le_bytes());
    }

    /// Reshaping a contiguous tensor shares storage instead of copying.
    #[test]
    fn test_reshape_contiguous_is_view() {
        let tensor = Tensor::from_slice(
            &f32_bytes(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
            &[2, 3],
            DType::Float32,
        )
        .expect("valid tensor");
        let reshaped = tensor.reshape(&[3, 2]).expect("valid reshape");
        assert!(reshaped.storage().ptr_eq(tensor.storage()));
        assert_eq!(reshaped.shape(), &[3, 2]);
        assert_eq!(reshaped.byte_offset(), tensor.byte_offset());
        assert_eq!(
            tensor.reshape(&[4, 2]),
            Err(TensorError::NumelMismatch { from: 6, to: 8 })
        );
    }

    /// A single `-1` is inferred; more than one, or an impossible fit, is rejected.
    #[test]
    fn test_reshape_infers_one_dimension() {
        let tensor = Tensor::zeros(&[2, 3, 4], DType::Uint8).expect("valid tensor");

        let inferred = tensor.reshape(&[2, -1, 3]).expect("valid reshape");
        assert_eq!(inferred.shape(), &[2, 4, 3]);
        assert!(inferred.storage().ptr_eq(tensor.storage()));

        let flat = tensor.reshape(&[-1]).expect("valid reshape");
        assert_eq!(flat.shape(), &[24]);

        assert_eq!(
            tensor.reshape(&[-1, -1]),
            Err(TensorError::AmbiguousReshape)
        );
        // 24 elements cannot be split into rows of 5.
        assert_eq!(
            tensor.reshape(&[-1, 5]),
            Err(TensorError::NumelMismatch { from: 24, to: 5 })
        );
        assert_eq!(tensor.reshape(&[-2, 4]), Err(TensorError::NegativeDim(-2)));
    }

    /// Wildcard inference on a tensor with no elements.
    #[test]
    fn test_reshape_inference_on_empty_tensors() {
        let empty = Tensor::zeros(&[0, 3], DType::Float32).expect("valid tensor");

        // The known extents multiply to 3 and 0 is a multiple of 3, so the
        // wildcard resolves to 0.
        let inferred = empty.reshape(&[-1, 3]).expect("valid reshape");
        assert_eq!(inferred.shape(), &[0, 3]);

        // With a known product of 0 the wildcard has no unique value.
        assert_eq!(
            empty.reshape(&[0, -1]),
            Err(TensorError::NumelMismatch { from: 0, to: 0 })
        );
    }

    /// Reshaping a non-contiguous view gathers the elements into new storage.
    #[test]
    fn test_reshape_strided_copies() {
        // Strides [1, 2] over these bytes read as the 2x2 matrix [[1, 2], [3, 4]].
        let storage = Storage::from_slice(&f32_bytes(&[1.0, 3.0, 2.0, 4.0]));
        let tensor = Tensor::from_storage(storage, DType::Float32, vec![2, 2], Some(vec![1, 2]), 0)
            .expect("valid tensor");
        assert!(!tensor.is_contiguous());

        let reshaped = tensor.reshape(&[4]).expect("valid reshape");
        assert!(!reshaped.storage().ptr_eq(tensor.storage()));
        assert!(reshaped.is_contiguous());
        assert_eq!(
            reshaped.to_contiguous_bytes().as_ref(),
            f32_bytes(&[1.0, 2.0, 3.0, 4.0]).as_slice()
        );
    }

    /// Contiguous tensors lend their bytes; strided ones return a gathered copy.
    #[test]
    fn test_to_contiguous_bytes_borrows_when_contiguous() {
        let tensor =
            Tensor::from_slice(&f32_bytes(&[1.0, 2.0]), &[2], DType::Float32).expect("valid");
        assert!(matches!(tensor.to_contiguous_bytes(), Cow::Borrowed(_)));

        // Stride 2 selects every other element: 1.0 and 3.0.
        let storage = Storage::from_slice(&f32_bytes(&[1.0, 2.0, 3.0, 4.0]));
        let strided = Tensor::from_storage(storage, DType::Float32, vec![2], Some(vec![2]), 0)
            .expect("valid tensor");
        assert!(!strided.is_contiguous());
        let gathered = strided.to_contiguous_bytes();
        assert!(matches!(gathered, Cow::Owned(_)));
        assert_eq!(gathered.as_ref(), f32_bytes(&[1.0, 3.0]).as_slice());
    }

    /// Gathering follows strides on every axis.
    #[test]
    fn test_strided_gather_multi_dimensional() {
        let values: Vec<f32> = (0..12).map(|v| v as f32).collect();
        // Shape [3, 2] with strides [4, 2] visits elements 0, 2, 4, 6, 8, 10.
        let storage = Storage::from_slice(&f32_bytes(&values));
        let view = Tensor::from_storage(storage, DType::Float32, vec![3, 2], Some(vec![4, 2]), 0)
            .expect("valid tensor");
        assert_eq!(
            view.to_contiguous_bytes().as_ref(),
            f32_bytes(&[0.0, 2.0, 4.0, 6.0, 8.0, 10.0]).as_slice()
        );
    }

    /// `unstack` returns views of the stacked storage that equal the inputs.
    #[test]
    fn test_stack_and_unstack_roundtrip() {
        let tensors: Vec<Tensor> = (0..3)
            .map(|i| {
                Tensor::from_slice(
                    &f32_bytes(&[i as f32, i as f32 + 0.5]),
                    &[2],
                    DType::Float32,
                )
                .expect("valid tensor")
            })
            .collect();

        let stacked = Tensor::stack(&tensors).expect("valid stack");
        assert_eq!(stacked.shape(), &[3, 2]);
        assert!(stacked.is_contiguous());
        assert_eq!(stacked.storage().as_slice().as_ptr() as usize % 64, 0);

        let views = stacked.unstack().expect("valid unstack");
        assert_eq!(views.len(), 3);
        for (index, (view, original)) in views.iter().zip(&tensors).enumerate() {
            assert!(view.storage().ptr_eq(stacked.storage()));
            // Each slice holds two f32 values, i.e. 8 bytes.
            assert_eq!(view.byte_offset(), index * 8);
            assert_eq!(view, original);
        }
    }

    /// `stack` needs at least one tensor and uniform dtype and shape.
    #[test]
    fn test_stack_rejects_empty_and_mismatched() {
        assert_eq!(Tensor::stack(&[]), Err(TensorError::EmptyStack));

        let a = Tensor::zeros(&[2], DType::Float32).expect("valid tensor");
        let b = Tensor::zeros(&[3], DType::Float32).expect("valid tensor");
        let c = Tensor::zeros(&[2], DType::Int32).expect("valid tensor");
        assert_eq!(
            Tensor::stack(&[a.clone(), b]),
            Err(TensorError::StackMismatch { index: 1 })
        );
        assert_eq!(
            Tensor::stack(&[a, c]),
            Err(TensorError::StackMismatch { index: 1 })
        );
    }

    /// A 0-dimensional tensor has no leading axis to split.
    #[test]
    fn test_unstack_scalar_fails() {
        let scalar = Tensor::zeros(&[], DType::Float32).expect("valid tensor");
        assert_eq!(scalar.unstack(), Err(TensorError::UnstackScalar));
    }

    /// Equality compares dtype, shape and logical bytes, not memory layout.
    #[test]
    fn test_partial_eq_is_logical() {
        // The strided view selects elements 0 and 2, i.e. 1.0 and 3.0.
        let storage = Storage::from_slice(&f32_bytes(&[1.0, 2.0, 3.0, 4.0]));
        let strided = Tensor::from_storage(storage, DType::Float32, vec![2], Some(vec![2]), 0)
            .expect("valid tensor");
        let contiguous =
            Tensor::from_slice(&f32_bytes(&[1.0, 3.0]), &[2], DType::Float32).expect("valid");
        assert_eq!(strided, contiguous);

        // Identical bytes under different dtypes are not equal.
        let other_dtype = Tensor::from_slice(&[0u8; 2], &[2], DType::Uint8).expect("valid");
        let same_bytes = Tensor::from_slice(&[0u8; 2], &[2], DType::Int8).expect("valid");
        assert_ne!(other_dtype, same_bytes);

        // Identical bytes under different shapes are not equal.
        let flat = Tensor::zeros(&[4], DType::Float32).expect("valid");
        let square = Tensor::zeros(&[2, 2], DType::Float32).expect("valid");
        assert_ne!(flat, square);
    }

    /// A view may start partway into its storage.
    #[test]
    fn test_view_with_byte_offset() {
        let storage = Storage::from_slice(&f32_bytes(&[1.0, 2.0, 3.0, 4.0]));
        let tail =
            Tensor::from_storage(storage, DType::Float32, vec![2], None, 8).expect("valid tensor");
        assert_eq!(
            tail.to_contiguous_bytes().as_ref(),
            f32_bytes(&[3.0, 4.0]).as_slice()
        );
        assert_eq!(tail.byte_offset(), 8);
    }

    /// Even a view with no elements must start within the storage bounds.
    #[test]
    fn test_empty_view_offset_must_stay_inside_storage() {
        let storage = Storage::zeroed(8);
        // One byte past the end is rejected although no bytes are required.
        assert_eq!(
            Tensor::from_storage(storage.clone(), DType::Float32, vec![0], None, 9),
            Err(TensorError::OutOfBounds {
                required: 0,
                byte_offset: 9,
                available: 8
            })
        );
        // Starting exactly at the end is allowed for an empty view.
        let at_end = Tensor::from_storage(storage, DType::Float32, vec![0], None, 8)
            .expect("valid empty view");
        assert!(at_end.to_contiguous_bytes().is_empty());
    }

    /// A tensor with a zero extent has no elements, bytes or slices.
    #[test]
    fn test_empty_tensor() {
        let tensor = Tensor::zeros(&[0, 3], DType::Float32).expect("valid tensor");
        assert_eq!(tensor.numel(), 0);
        assert_eq!(tensor.nbytes(), 0);
        assert!(tensor.is_contiguous());
        assert!(tensor.to_contiguous_bytes().is_empty());
        let views = tensor.unstack().expect("valid unstack");
        assert!(views.is_empty());
    }

    /// Row-major strides for a handful of shapes, including a zero extent.
    #[test]
    fn test_contiguous_strides_table() {
        assert_eq!(contiguous_strides(&[]), Vec::<i64>::new());
        assert_eq!(contiguous_strides(&[5]), vec![1]);
        assert_eq!(contiguous_strides(&[2, 3]), vec![3, 1]);
        assert_eq!(contiguous_strides(&[2, 3, 4]), vec![12, 4, 1]);
        assert_eq!(contiguous_strides(&[0, 3]), vec![3, 1]);
    }

    /// Explicit strides that match the row-major layout count as contiguous.
    #[test]
    fn test_explicit_contiguous_strides_detected() {
        let storage = Storage::from_slice(&f32_bytes(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]));
        let tensor = Tensor::from_storage(storage, DType::Float32, vec![2, 3], Some(vec![3, 1]), 0)
            .expect("valid tensor");
        assert!(tensor.is_contiguous());
        assert!(matches!(tensor.to_contiguous_bytes(), Cow::Borrowed(_)));
    }
}