1use crate::{DType, PixelFormat, Tensor, TensorMemory, TensorTrait};
5use half::f16;
6use std::fmt;
7
8#[non_exhaustive]
10pub enum TensorDyn {
11 U8(Tensor<u8>),
13 I8(Tensor<i8>),
15 U16(Tensor<u16>),
17 I16(Tensor<i16>),
19 U32(Tensor<u32>),
21 I32(Tensor<i32>),
23 U64(Tensor<u64>),
25 I64(Tensor<i64>),
27 F16(Tensor<f16>),
29 F32(Tensor<f32>),
31 F64(Tensor<f64>),
33}
34
// Forwards a method call to whichever typed tensor the enum currently holds.
// Usage: `dispatch!(target, method, arg1, arg2, ...)`. Each argument expression
// is expanded into every arm, but only the matching arm ever runs, so each
// argument is evaluated exactly once.
macro_rules! dispatch {
    ($target:expr, $call:ident $(, $extra:expr)*) => {
        match $target {
            TensorDyn::U8(inner) => inner.$call($($extra),*),
            TensorDyn::I8(inner) => inner.$call($($extra),*),
            TensorDyn::U16(inner) => inner.$call($($extra),*),
            TensorDyn::I16(inner) => inner.$call($($extra),*),
            TensorDyn::U32(inner) => inner.$call($($extra),*),
            TensorDyn::I32(inner) => inner.$call($($extra),*),
            TensorDyn::U64(inner) => inner.$call($($extra),*),
            TensorDyn::I64(inner) => inner.$call($($extra),*),
            TensorDyn::F16(inner) => inner.$call($($extra),*),
            TensorDyn::F32(inner) => inner.$call($($extra),*),
            TensorDyn::F64(inner) => inner.$call($($extra),*),
        }
    };
}
53
// Emits three accessors for a single variant: a shared borrow (`as_*`), a
// mutable borrow (`as_*_mut`), and a consuming conversion (`into_*`).
// Intended to be invoked inside `impl TensorDyn`.
macro_rules! downcast_methods {
    ($variant:ident, $ty:ty, $as_name:ident, $as_mut_name:ident, $into_name:ident) => {
        /// Borrows the inner typed tensor if this value holds the matching variant,
        /// otherwise returns `None`.
        pub fn $as_name(&self) -> Option<&Tensor<$ty>> {
            if let Self::$variant(inner) = self {
                Some(inner)
            } else {
                None
            }
        }

        /// Mutably borrows the inner typed tensor if this value holds the matching
        /// variant, otherwise returns `None`.
        pub fn $as_mut_name(&mut self) -> Option<&mut Tensor<$ty>> {
            if let Self::$variant(inner) = self {
                Some(inner)
            } else {
                None
            }
        }

        /// Consumes `self` and yields the inner typed tensor; on a variant mismatch
        /// the untouched value is handed back in `Err` so the caller loses nothing.
        // The `Err` payload is the whole enum by design (so it can be retried with
        // another `into_*`); clippy's large-error lint is therefore silenced.
        #[allow(clippy::result_large_err)]
        pub fn $into_name(self) -> Result<Tensor<$ty>, Self> {
            match self {
                Self::$variant(inner) => Ok(inner),
                unmatched => Err(unmatched),
            }
        }
    };
}
84
85impl TensorDyn {
86 pub fn dtype(&self) -> DType {
88 match self {
89 Self::U8(_) => DType::U8,
90 Self::I8(_) => DType::I8,
91 Self::U16(_) => DType::U16,
92 Self::I16(_) => DType::I16,
93 Self::U32(_) => DType::U32,
94 Self::I32(_) => DType::I32,
95 Self::U64(_) => DType::U64,
96 Self::I64(_) => DType::I64,
97 Self::F16(_) => DType::F16,
98 Self::F32(_) => DType::F32,
99 Self::F64(_) => DType::F64,
100 }
101 }
102
103 pub fn shape(&self) -> &[usize] {
105 dispatch!(self, shape)
106 }
107
108 pub fn name(&self) -> String {
110 dispatch!(self, name)
111 }
112
113 pub fn format(&self) -> Option<PixelFormat> {
115 dispatch!(self, format)
116 }
117
118 pub fn width(&self) -> Option<usize> {
120 dispatch!(self, width)
121 }
122
123 pub fn height(&self) -> Option<usize> {
125 dispatch!(self, height)
126 }
127
128 pub fn size(&self) -> usize {
130 dispatch!(self, size)
131 }
132
133 pub fn memory(&self) -> TensorMemory {
135 dispatch!(self, memory)
136 }
137
138 pub fn reshape(&mut self, shape: &[usize]) -> crate::Result<()> {
140 dispatch!(self, reshape, shape)
141 }
142
143 pub fn set_format(&mut self, format: PixelFormat) -> crate::Result<()> {
161 dispatch!(self, set_format, format)
162 }
163
164 pub fn with_format(mut self, format: PixelFormat) -> crate::Result<Self> {
181 self.set_format(format)?;
182 Ok(self)
183 }
184
185 pub fn row_stride(&self) -> Option<usize> {
187 dispatch!(self, row_stride)
188 }
189
190 pub fn effective_row_stride(&self) -> Option<usize> {
192 dispatch!(self, effective_row_stride)
193 }
194
195 pub fn set_row_stride(&mut self, stride: usize) -> crate::Result<()> {
201 dispatch!(self, set_row_stride, stride)
202 }
203
204 pub fn with_row_stride(mut self, stride: usize) -> crate::Result<Self> {
206 self.set_row_stride(stride)?;
207 Ok(self)
208 }
209
210 pub fn plane_offset(&self) -> Option<usize> {
212 dispatch!(self, plane_offset)
213 }
214
215 pub fn set_plane_offset(&mut self, offset: usize) {
217 dispatch!(self, set_plane_offset, offset)
218 }
219
220 pub fn with_plane_offset(mut self, offset: usize) -> Self {
222 self.set_plane_offset(offset);
223 self
224 }
225
226 pub fn quantization(&self) -> Option<&crate::Quantization> {
230 match self {
231 Self::U8(t) => t.quantization(),
232 Self::I8(t) => t.quantization(),
233 Self::U16(t) => t.quantization(),
234 Self::I16(t) => t.quantization(),
235 Self::U32(t) => t.quantization(),
236 Self::I32(t) => t.quantization(),
237 Self::U64(t) => t.quantization(),
238 Self::I64(t) => t.quantization(),
239 Self::F16(_) | Self::F32(_) | Self::F64(_) => None,
240 }
241 }
242
243 pub fn set_quantization(&mut self, q: crate::Quantization) -> crate::Result<()> {
247 match self {
248 Self::U8(t) => t.set_quantization(q),
249 Self::I8(t) => t.set_quantization(q),
250 Self::U16(t) => t.set_quantization(q),
251 Self::I16(t) => t.set_quantization(q),
252 Self::U32(t) => t.set_quantization(q),
253 Self::I32(t) => t.set_quantization(q),
254 Self::U64(t) => t.set_quantization(q),
255 Self::I64(t) => t.set_quantization(q),
256 Self::F16(_) | Self::F32(_) | Self::F64(_) => Err(crate::Error::QuantizationInvalid {
257 field: "dtype_is_integer",
258 expected: "integer tensor dtype (u8/i8/u16/i16/u32/i32/u64/i64)".to_string(),
259 got: format!("{:?}", self.dtype()),
260 }),
261 }
262 }
263
264 pub fn with_quantization(mut self, q: crate::Quantization) -> crate::Result<Self> {
267 self.set_quantization(q)?;
268 Ok(self)
269 }
270
271 pub fn clear_quantization(&mut self) {
273 match self {
274 Self::U8(t) => t.clear_quantization(),
275 Self::I8(t) => t.clear_quantization(),
276 Self::U16(t) => t.clear_quantization(),
277 Self::I16(t) => t.clear_quantization(),
278 Self::U32(t) => t.clear_quantization(),
279 Self::I32(t) => t.clear_quantization(),
280 Self::U64(t) => t.clear_quantization(),
281 Self::I64(t) => t.clear_quantization(),
282 Self::F16(_) | Self::F32(_) | Self::F64(_) => {}
283 }
284 }
285
286 #[cfg(unix)]
288 pub fn clone_fd(&self) -> crate::Result<std::os::fd::OwnedFd> {
289 dispatch!(self, clone_fd)
290 }
291
292 #[cfg(target_os = "linux")]
303 pub fn dmabuf_clone(&self) -> crate::Result<std::os::fd::OwnedFd> {
304 if self.memory() != TensorMemory::Dma {
305 return Err(crate::Error::NotImplemented(format!(
306 "dmabuf_clone requires DMA-backed tensor, got {:?}",
307 self.memory()
308 )));
309 }
310 self.clone_fd()
311 }
312
313 #[cfg(target_os = "linux")]
324 pub fn dmabuf(&self) -> crate::Result<std::os::fd::BorrowedFd<'_>> {
325 dispatch!(self, dmabuf)
326 }
327
328 pub fn is_multiplane(&self) -> bool {
330 dispatch!(self, is_multiplane)
331 }
332
333 pub fn buffer_identity(&self) -> &crate::BufferIdentity {
343 dispatch!(self, buffer_identity)
344 }
345
346 pub fn aliases(&self, other: &Self) -> bool {
364 if self.buffer_identity().id() == other.buffer_identity().id() {
365 return true;
366 }
367 if self.memory() != other.memory() {
368 return false;
369 }
370 #[cfg(target_os = "linux")]
371 if self.memory() == TensorMemory::Dma {
372 use std::os::fd::AsRawFd;
373 if let (Ok(a), Ok(b)) = (self.dmabuf(), other.dmabuf()) {
374 return a.as_raw_fd() == b.as_raw_fd();
375 }
376 }
377 false
378 }
379
380 downcast_methods!(U8, u8, as_u8, as_u8_mut, into_u8);
383 downcast_methods!(I8, i8, as_i8, as_i8_mut, into_i8);
384 downcast_methods!(U16, u16, as_u16, as_u16_mut, into_u16);
385 downcast_methods!(I16, i16, as_i16, as_i16_mut, into_i16);
386 downcast_methods!(U32, u32, as_u32, as_u32_mut, into_u32);
387 downcast_methods!(I32, i32, as_i32, as_i32_mut, into_i32);
388 downcast_methods!(U64, u64, as_u64, as_u64_mut, into_u64);
389 downcast_methods!(I64, i64, as_i64, as_i64_mut, into_i64);
390 downcast_methods!(F16, f16, as_f16, as_f16_mut, into_f16);
391 downcast_methods!(F32, f32, as_f32, as_f32_mut, into_f32);
392 downcast_methods!(F64, f64, as_f64, as_f64_mut, into_f64);
393
394 pub fn new(
396 shape: &[usize],
397 dtype: DType,
398 memory: Option<TensorMemory>,
399 name: Option<&str>,
400 ) -> crate::Result<Self> {
401 match dtype {
402 DType::U8 => Tensor::<u8>::new(shape, memory, name).map(Self::U8),
403 DType::I8 => Tensor::<i8>::new(shape, memory, name).map(Self::I8),
404 DType::U16 => Tensor::<u16>::new(shape, memory, name).map(Self::U16),
405 DType::I16 => Tensor::<i16>::new(shape, memory, name).map(Self::I16),
406 DType::U32 => Tensor::<u32>::new(shape, memory, name).map(Self::U32),
407 DType::I32 => Tensor::<i32>::new(shape, memory, name).map(Self::I32),
408 DType::U64 => Tensor::<u64>::new(shape, memory, name).map(Self::U64),
409 DType::I64 => Tensor::<i64>::new(shape, memory, name).map(Self::I64),
410 DType::F16 => Tensor::<f16>::new(shape, memory, name).map(Self::F16),
411 DType::F32 => Tensor::<f32>::new(shape, memory, name).map(Self::F32),
412 DType::F64 => Tensor::<f64>::new(shape, memory, name).map(Self::F64),
413 }
414 }
415
416 #[cfg(unix)]
418 pub fn from_fd(
419 fd: std::os::fd::OwnedFd,
420 shape: &[usize],
421 dtype: DType,
422 name: Option<&str>,
423 ) -> crate::Result<Self> {
424 match dtype {
425 DType::U8 => Tensor::<u8>::from_fd(fd, shape, name).map(Self::U8),
426 DType::I8 => Tensor::<i8>::from_fd(fd, shape, name).map(Self::I8),
427 DType::U16 => Tensor::<u16>::from_fd(fd, shape, name).map(Self::U16),
428 DType::I16 => Tensor::<i16>::from_fd(fd, shape, name).map(Self::I16),
429 DType::U32 => Tensor::<u32>::from_fd(fd, shape, name).map(Self::U32),
430 DType::I32 => Tensor::<i32>::from_fd(fd, shape, name).map(Self::I32),
431 DType::U64 => Tensor::<u64>::from_fd(fd, shape, name).map(Self::U64),
432 DType::I64 => Tensor::<i64>::from_fd(fd, shape, name).map(Self::I64),
433 DType::F16 => Tensor::<f16>::from_fd(fd, shape, name).map(Self::F16),
434 DType::F32 => Tensor::<f32>::from_fd(fd, shape, name).map(Self::F32),
435 DType::F64 => Tensor::<f64>::from_fd(fd, shape, name).map(Self::F64),
436 }
437 }
438
439 #[cfg(target_os = "macos")]
447 pub unsafe fn from_iosurface(
448 surface_ref: *mut std::ffi::c_void,
449 shape: &[usize],
450 dtype: DType,
451 name: Option<&str>,
452 ) -> crate::Result<Self> {
453 unsafe {
454 match dtype {
455 DType::U8 => Tensor::<u8>::from_iosurface(surface_ref, shape, name).map(Self::U8),
456 DType::I8 => Tensor::<i8>::from_iosurface(surface_ref, shape, name).map(Self::I8),
457 DType::U16 => {
458 Tensor::<u16>::from_iosurface(surface_ref, shape, name).map(Self::U16)
459 }
460 DType::I16 => {
461 Tensor::<i16>::from_iosurface(surface_ref, shape, name).map(Self::I16)
462 }
463 DType::U32 => {
464 Tensor::<u32>::from_iosurface(surface_ref, shape, name).map(Self::U32)
465 }
466 DType::I32 => {
467 Tensor::<i32>::from_iosurface(surface_ref, shape, name).map(Self::I32)
468 }
469 DType::U64 => {
470 Tensor::<u64>::from_iosurface(surface_ref, shape, name).map(Self::U64)
471 }
472 DType::I64 => {
473 Tensor::<i64>::from_iosurface(surface_ref, shape, name).map(Self::I64)
474 }
475 DType::F16 => {
476 Tensor::<f16>::from_iosurface(surface_ref, shape, name).map(Self::F16)
477 }
478 DType::F32 => {
479 Tensor::<f32>::from_iosurface(surface_ref, shape, name).map(Self::F32)
480 }
481 DType::F64 => {
482 Tensor::<f64>::from_iosurface(surface_ref, shape, name).map(Self::F64)
483 }
484 }
485 }
486 }
487
488 #[cfg(target_os = "macos")]
491 pub fn iosurface_id(&self) -> Option<u32> {
492 dispatch!(self, iosurface_id)
493 }
494
495 #[cfg(target_os = "macos")]
499 pub fn iosurface_ref(&self) -> Option<*mut std::ffi::c_void> {
500 dispatch!(self, iosurface_ref)
501 }
502
503 pub fn image(
521 width: usize,
522 height: usize,
523 format: PixelFormat,
524 dtype: DType,
525 memory: Option<TensorMemory>,
526 ) -> crate::Result<Self> {
527 match dtype {
528 DType::U8 => Tensor::<u8>::image(width, height, format, memory).map(Self::U8),
529 DType::I8 => Tensor::<i8>::image(width, height, format, memory).map(Self::I8),
530 DType::U16 => Tensor::<u16>::image(width, height, format, memory).map(Self::U16),
531 DType::I16 => Tensor::<i16>::image(width, height, format, memory).map(Self::I16),
532 DType::U32 => Tensor::<u32>::image(width, height, format, memory).map(Self::U32),
533 DType::I32 => Tensor::<i32>::image(width, height, format, memory).map(Self::I32),
534 DType::U64 => Tensor::<u64>::image(width, height, format, memory).map(Self::U64),
535 DType::I64 => Tensor::<i64>::image(width, height, format, memory).map(Self::I64),
536 DType::F16 => Tensor::<f16>::image(width, height, format, memory).map(Self::F16),
537 DType::F32 => Tensor::<f32>::image(width, height, format, memory).map(Self::F32),
538 DType::F64 => Tensor::<f64>::image(width, height, format, memory).map(Self::F64),
539 }
540 }
541
542 pub fn image_with_stride(
568 width: usize,
569 height: usize,
570 format: PixelFormat,
571 dtype: DType,
572 row_stride_bytes: usize,
573 memory: Option<TensorMemory>,
574 ) -> crate::Result<Self> {
575 match dtype {
576 DType::U8 => {
577 Tensor::<u8>::image_with_stride(width, height, format, row_stride_bytes, memory)
578 .map(Self::U8)
579 }
580 DType::I8 => {
581 Tensor::<i8>::image_with_stride(width, height, format, row_stride_bytes, memory)
582 .map(Self::I8)
583 }
584 DType::U16 => {
585 Tensor::<u16>::image_with_stride(width, height, format, row_stride_bytes, memory)
586 .map(Self::U16)
587 }
588 DType::I16 => {
589 Tensor::<i16>::image_with_stride(width, height, format, row_stride_bytes, memory)
590 .map(Self::I16)
591 }
592 DType::U32 => {
593 Tensor::<u32>::image_with_stride(width, height, format, row_stride_bytes, memory)
594 .map(Self::U32)
595 }
596 DType::I32 => {
597 Tensor::<i32>::image_with_stride(width, height, format, row_stride_bytes, memory)
598 .map(Self::I32)
599 }
600 DType::U64 => {
601 Tensor::<u64>::image_with_stride(width, height, format, row_stride_bytes, memory)
602 .map(Self::U64)
603 }
604 DType::I64 => {
605 Tensor::<i64>::image_with_stride(width, height, format, row_stride_bytes, memory)
606 .map(Self::I64)
607 }
608 DType::F16 => {
609 Tensor::<f16>::image_with_stride(width, height, format, row_stride_bytes, memory)
610 .map(Self::F16)
611 }
612 DType::F32 => {
613 Tensor::<f32>::image_with_stride(width, height, format, row_stride_bytes, memory)
614 .map(Self::F32)
615 }
616 DType::F64 => {
617 Tensor::<f64>::image_with_stride(width, height, format, row_stride_bytes, memory)
618 .map(Self::F64)
619 }
620 }
621 }
622}
623
624impl From<Tensor<u8>> for TensorDyn {
627 fn from(t: Tensor<u8>) -> Self {
628 Self::U8(t)
629 }
630}
631
632impl From<Tensor<i8>> for TensorDyn {
633 fn from(t: Tensor<i8>) -> Self {
634 Self::I8(t)
635 }
636}
637
638impl From<Tensor<u16>> for TensorDyn {
639 fn from(t: Tensor<u16>) -> Self {
640 Self::U16(t)
641 }
642}
643
644impl From<Tensor<i16>> for TensorDyn {
645 fn from(t: Tensor<i16>) -> Self {
646 Self::I16(t)
647 }
648}
649
650impl From<Tensor<u32>> for TensorDyn {
651 fn from(t: Tensor<u32>) -> Self {
652 Self::U32(t)
653 }
654}
655
656impl From<Tensor<i32>> for TensorDyn {
657 fn from(t: Tensor<i32>) -> Self {
658 Self::I32(t)
659 }
660}
661
662impl From<Tensor<u64>> for TensorDyn {
663 fn from(t: Tensor<u64>) -> Self {
664 Self::U64(t)
665 }
666}
667
668impl From<Tensor<i64>> for TensorDyn {
669 fn from(t: Tensor<i64>) -> Self {
670 Self::I64(t)
671 }
672}
673
674impl From<Tensor<f16>> for TensorDyn {
675 fn from(t: Tensor<f16>) -> Self {
676 Self::F16(t)
677 }
678}
679
680impl From<Tensor<f32>> for TensorDyn {
681 fn from(t: Tensor<f32>) -> Self {
682 Self::F32(t)
683 }
684}
685
686impl From<Tensor<f64>> for TensorDyn {
687 fn from(t: Tensor<f64>) -> Self {
688 Self::F64(t)
689 }
690}
691
692impl fmt::Debug for TensorDyn {
693 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
694 dispatch!(self, fmt, f)
695 }
696}
697
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_typed_tensor() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert_eq!(dyn_t.dtype(), DType::U8);
        assert_eq!(dyn_t.shape(), &[10]);
    }

    #[test]
    fn downcast_ref() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert!(dyn_t.as_u8().is_some());
        assert!(dyn_t.as_i8().is_none());
    }

    #[test]
    fn downcast_into() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        let back = dyn_t.into_u8().unwrap();
        assert_eq!(back.shape(), &[10]);
    }

    #[test]
    fn image_accessors() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgba));
        assert_eq!(dyn_t.width(), Some(640));
        assert_eq!(dyn_t.height(), Some(480));
        assert!(!dyn_t.is_multiplane());
    }

    #[test]
    fn image_constructor() {
        let dyn_t = TensorDyn::image(640, 480, PixelFormat::Rgb, DType::U8, None).unwrap();
        assert_eq!(dyn_t.dtype(), DType::U8);
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgb));
        assert_eq!(dyn_t.width(), Some(640));
    }

    #[test]
    fn image_constructor_i8() {
        let dyn_t = TensorDyn::image(640, 480, PixelFormat::Rgb, DType::I8, None).unwrap();
        assert_eq!(dyn_t.dtype(), DType::I8);
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgb));
    }

    #[test]
    fn set_format_packed() {
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        assert_eq!(t.format(), None);
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[test]
    fn set_format_planar() {
        let mut t = TensorDyn::new(&[3, 480, 640], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::PlanarRgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[test]
    fn set_format_rejects_wrong_shape() {
        let mut t = TensorDyn::new(&[480, 640, 4], DType::U8, None, None).unwrap();
        assert!(t.set_format(PixelFormat::Rgb).is_err());
    }

    #[test]
    fn with_format_builder() {
        let t = TensorDyn::new(&[480, 640, 4], DType::U8, None, None)
            .unwrap()
            .with_format(PixelFormat::Rgba)
            .unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn dmabuf_clone_mem_tensor_fails() {
        let t = TensorDyn::new(&[480, 640, 3], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        assert_eq!(t.memory(), TensorMemory::Mem);
        assert!(t.dmabuf_clone().is_err());
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn dmabuf_mem_tensor_fails() {
        let t = TensorDyn::new(&[480, 640, 3], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        assert!(t.dmabuf().is_err());
    }

    #[test]
    fn set_format_semi_planar_nv12() {
        // 720 rows = 480 luma rows + 240 interleaved chroma rows.
        let mut t = TensorDyn::new(&[720, 640], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        t.set_format(PixelFormat::Nv12).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Nv12));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[test]
    fn set_format_semi_planar_nv16() {
        // 960 rows = 480 luma rows + 480 interleaved chroma rows.
        let mut t = TensorDyn::new(&[960, 640], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        t.set_format(PixelFormat::Nv16).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Nv16));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[test]
    fn with_format_rejects_wrong_shape() {
        let result = TensorDyn::new(&[480, 640, 4], DType::U8, None, None)
            .unwrap()
            .with_format(PixelFormat::Rgb);
        assert!(result.is_err());
    }

    #[test]
    fn set_format_preserved_after_rejection() {
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));

        // A 3-channel shape cannot be Rgba, so the call must fail...
        assert!(t.set_format(PixelFormat::Rgba).is_err());

        // ...without disturbing the previously accepted format.
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
    }

    #[test]
    fn set_format_idempotent() {
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[test]
    fn set_row_stride_valid() {
        // 100 px * 4 channels = 400 bytes minimum; 512 leaves padding.
        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None).unwrap();
        t.set_row_stride(512).unwrap();
        assert_eq!(t.row_stride(), Some(512));
        assert_eq!(t.effective_row_stride(), Some(512));
    }

    #[test]
    fn set_row_stride_equals_min() {
        // 100 px * 3 channels = exactly the minimum stride.
        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgb, DType::U8, None).unwrap();
        t.set_row_stride(300).unwrap();
        assert_eq!(t.row_stride(), Some(300));
    }

    #[test]
    fn set_row_stride_too_small() {
        // 300 is below the 400-byte packed Rgba row; the stride must stay unset.
        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None).unwrap();
        assert!(t.set_row_stride(300).is_err());
        assert_eq!(t.row_stride(), None);
    }

    #[test]
    fn set_row_stride_zero() {
        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgb, DType::U8, None).unwrap();
        assert!(t.set_row_stride(0).is_err());
    }

    #[test]
    fn set_row_stride_requires_format() {
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        assert!(t.set_row_stride(2048).is_err());
    }

    #[test]
    fn effective_row_stride_without_stride() {
        let t = TensorDyn::image(100, 100, PixelFormat::Rgb, DType::U8, None).unwrap();
        assert_eq!(t.row_stride(), None);
        // Falls back to the packed row size: 100 px * 3 channels.
        assert_eq!(t.effective_row_stride(), Some(300));
    }

    #[test]
    fn effective_row_stride_no_format() {
        let t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        assert_eq!(t.effective_row_stride(), None);
    }

    #[test]
    fn with_row_stride_builder() {
        let t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None)
            .unwrap()
            .with_row_stride(512)
            .unwrap();
        assert_eq!(t.row_stride(), Some(512));
        assert_eq!(t.effective_row_stride(), Some(512));
    }

    #[test]
    fn with_row_stride_rejects_small() {
        let result = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None)
            .unwrap()
            .with_row_stride(200);
        assert!(result.is_err());
    }

    #[test]
    fn row_stride_survives_set_format_until_reshape() {
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        t.set_row_stride(2048).unwrap();
        assert_eq!(t.row_stride(), Some(2048));

        // Bgra needs 4 channels, so this is rejected and must not touch the stride.
        assert!(t.set_format(PixelFormat::Bgra).is_err());
        assert_eq!(t.row_stride(), Some(2048));

        // Re-applying the same format keeps the stride.
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.row_stride(), Some(2048));

        // Reshaping to a flat buffer drops both the stride and the format.
        t.reshape(&[480 * 640 * 3]).unwrap();
        assert_eq!(t.row_stride(), None);
        assert_eq!(t.format(), None);
    }

    #[test]
    fn set_format_different_compatible_clears_stride() {
        // Rgba and Bgra share the same shape, so switching between them succeeds,
        // but the custom stride is discarded.
        let mut t = TensorDyn::new(&[480, 640, 4], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::Rgba).unwrap();
        t.set_row_stride(4096).unwrap();
        assert_eq!(t.row_stride(), Some(4096));

        t.set_format(PixelFormat::Bgra).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Bgra));
        assert_eq!(t.row_stride(), None);
    }

    #[test]
    fn set_format_same_preserves_stride() {
        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgb, DType::U8, None).unwrap();
        t.set_row_stride(512).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.row_stride(), Some(512));
    }

    #[test]
    fn effective_row_stride_planar() {
        let t = TensorDyn::image(640, 480, PixelFormat::PlanarRgb, DType::U8, None).unwrap();
        // One plane row holds a single channel: 640 bytes.
        assert_eq!(t.effective_row_stride(), Some(640));
    }

    #[test]
    fn effective_row_stride_nv12() {
        let t = TensorDyn::image(640, 480, PixelFormat::Nv12, DType::U8, None).unwrap();
        // Luma rows are one byte per pixel: 640 bytes.
        assert_eq!(t.effective_row_stride(), Some(640));
    }

    #[test]
    fn map_rejects_strided_tensor() {
        let mut t =
            Tensor::<u8>::image(100, 100, PixelFormat::Rgba, Some(TensorMemory::Mem)).unwrap();
        assert!(t.map().is_ok());

        // Once rows are padded, a flat mapping would be misleading, so map() refuses.
        t.set_row_stride(512).unwrap();
        let mapped = t.map();
        assert!(mapped.is_err());
    }

    #[test]
    fn plane_offset_default_none() {
        let t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None).unwrap();
        assert_eq!(t.plane_offset(), None);
    }

    #[test]
    fn set_plane_offset_basic() {
        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None).unwrap();
        t.set_plane_offset(4096);
        assert_eq!(t.plane_offset(), Some(4096));
    }

    #[test]
    fn set_plane_offset_zero() {
        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgb, DType::U8, None).unwrap();
        t.set_plane_offset(0);
        assert_eq!(t.plane_offset(), Some(0));
    }

    #[test]
    fn set_plane_offset_no_format() {
        // Unlike the row stride, a plane offset does not require a pixel format.
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        t.set_plane_offset(4096);
        assert_eq!(t.plane_offset(), Some(4096));
    }

    #[test]
    fn with_plane_offset_builder() {
        let t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None)
            .unwrap()
            .with_plane_offset(8192);
        assert_eq!(t.plane_offset(), Some(8192));
    }

    #[test]
    fn plane_offset_survives_set_format_until_reshape() {
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        t.set_plane_offset(4096);
        assert_eq!(t.plane_offset(), Some(4096));

        // Re-applying the same format keeps the offset.
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.plane_offset(), Some(4096));

        // Reshaping to a flat buffer drops both the offset and the format.
        t.reshape(&[480 * 640 * 3]).unwrap();
        assert_eq!(t.plane_offset(), None);
        assert_eq!(t.format(), None);
    }

    #[test]
    fn map_rejects_offset_tensor() {
        let mut t =
            Tensor::<u8>::image(100, 100, PixelFormat::Rgba, Some(TensorMemory::Mem)).unwrap();
        assert!(t.map().is_ok());

        // A non-zero plane offset makes a whole-buffer mapping unsafe to expose.
        t.set_plane_offset(4096);
        assert!(t.map().is_err());
    }

    #[test]
    fn map_accepts_zero_offset_tensor() {
        // An explicit zero offset is equivalent to no offset for mapping purposes.
        let mut t =
            Tensor::<u8>::image(100, 100, PixelFormat::Rgba, Some(TensorMemory::Mem)).unwrap();
        t.set_plane_offset(0);
        assert!(t.map().is_ok());
    }

    #[test]
    fn from_planes_propagates_plane_offset() {
        let mut luma =
            Tensor::<u8>::new(&[480, 640], Some(TensorMemory::Mem), Some("luma")).unwrap();
        luma.set_plane_offset(4096);
        let chroma =
            Tensor::<u8>::new(&[240, 640], Some(TensorMemory::Mem), Some("chroma")).unwrap();
        let combined = Tensor::<u8>::from_planes(luma, chroma, PixelFormat::Nv12).unwrap();
        assert_eq!(combined.plane_offset(), Some(4096));
    }
}