Skip to main content

edgefirst_tensor/
tensor_dyn.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{DType, PixelFormat, Tensor, TensorMemory, TensorTrait};
5use half::f16;
6use std::fmt;
7
8/// Type-erased tensor. Wraps a `Tensor<T>` with runtime element type.
9#[non_exhaustive]
10pub enum TensorDyn {
11    /// Unsigned 8-bit integer tensor.
12    U8(Tensor<u8>),
13    /// Signed 8-bit integer tensor.
14    I8(Tensor<i8>),
15    /// Unsigned 16-bit integer tensor.
16    U16(Tensor<u16>),
17    /// Signed 16-bit integer tensor.
18    I16(Tensor<i16>),
19    /// Unsigned 32-bit integer tensor.
20    U32(Tensor<u32>),
21    /// Signed 32-bit integer tensor.
22    I32(Tensor<i32>),
23    /// Unsigned 64-bit integer tensor.
24    U64(Tensor<u64>),
25    /// Signed 64-bit integer tensor.
26    I64(Tensor<i64>),
27    /// 16-bit floating-point tensor.
28    F16(Tensor<f16>),
29    /// 32-bit floating-point tensor.
30    F32(Tensor<f32>),
31    /// 64-bit floating-point tensor.
32    F64(Tensor<f64>),
33}
34
/// Forward a method call to whichever `Tensor<T>` a `TensorDyn` holds.
///
/// `dispatch!(value, method, a, b)` expands to a `match` with one arm per
/// variant, each evaluating `inner.method(a, b)` on the wrapped tensor. All
/// arms must therefore produce the same type.
macro_rules! dispatch {
    ($target:expr, $call:ident $(, $param:expr)*) => {
        match $target {
            TensorDyn::U8(inner) => inner.$call($($param),*),
            TensorDyn::I8(inner) => inner.$call($($param),*),
            TensorDyn::U16(inner) => inner.$call($($param),*),
            TensorDyn::I16(inner) => inner.$call($($param),*),
            TensorDyn::U32(inner) => inner.$call($($param),*),
            TensorDyn::I32(inner) => inner.$call($($param),*),
            TensorDyn::U64(inner) => inner.$call($($param),*),
            TensorDyn::I64(inner) => inner.$call($($param),*),
            TensorDyn::F16(inner) => inner.$call($($param),*),
            TensorDyn::F32(inner) => inner.$call($($param),*),
            TensorDyn::F64(inner) => inner.$call($($param),*),
        }
    };
}
53
/// Emit the three downcast accessors for a single variant: a shared borrow,
/// a mutable borrow, and a by-value conversion.
macro_rules! downcast_methods {
    ($variant:ident, $elem:ty, $as_name:ident, $as_mut_name:ident, $into_name:ident) => {
        /// Borrows the wrapped tensor, or yields `None` when a different
        /// element type is stored.
        pub fn $as_name(&self) -> Option<&Tensor<$elem>> {
            if let Self::$variant(inner) = self {
                Some(inner)
            } else {
                None
            }
        }

        /// Mutably borrows the wrapped tensor, or yields `None` when a
        /// different element type is stored.
        pub fn $as_mut_name(&mut self) -> Option<&mut Tensor<$elem>> {
            if let Self::$variant(inner) = self {
                Some(inner)
            } else {
                None
            }
        }

        /// Takes ownership of the wrapped tensor when the element type
        /// matches; otherwise hands the untouched `TensorDyn` back as `Err`.
        // The error payload is the whole `TensorDyn`, so it is large by design.
        #[allow(clippy::result_large_err)]
        pub fn $into_name(self) -> Result<Tensor<$elem>, Self> {
            match self {
                Self::$variant(inner) => Ok(inner),
                mismatched => Err(mismatched),
            }
        }
    };
}
84
85impl TensorDyn {
86    /// Return the runtime element type discriminant.
87    pub fn dtype(&self) -> DType {
88        match self {
89            Self::U8(_) => DType::U8,
90            Self::I8(_) => DType::I8,
91            Self::U16(_) => DType::U16,
92            Self::I16(_) => DType::I16,
93            Self::U32(_) => DType::U32,
94            Self::I32(_) => DType::I32,
95            Self::U64(_) => DType::U64,
96            Self::I64(_) => DType::I64,
97            Self::F16(_) => DType::F16,
98            Self::F32(_) => DType::F32,
99            Self::F64(_) => DType::F64,
100        }
101    }
102
103    /// Return the tensor shape.
104    pub fn shape(&self) -> &[usize] {
105        dispatch!(self, shape)
106    }
107
108    /// Return the tensor name.
109    pub fn name(&self) -> String {
110        dispatch!(self, name)
111    }
112
113    /// Return the pixel format (None if not an image tensor).
114    pub fn format(&self) -> Option<PixelFormat> {
115        dispatch!(self, format)
116    }
117
118    /// Return the image width (None if not an image tensor).
119    pub fn width(&self) -> Option<usize> {
120        dispatch!(self, width)
121    }
122
123    /// Return the image height (None if not an image tensor).
124    pub fn height(&self) -> Option<usize> {
125        dispatch!(self, height)
126    }
127
128    /// Return the total size of this tensor in bytes.
129    pub fn size(&self) -> usize {
130        dispatch!(self, size)
131    }
132
133    /// Return the memory allocation type.
134    pub fn memory(&self) -> TensorMemory {
135        dispatch!(self, memory)
136    }
137
138    /// Reshape this tensor. Total element count must remain the same.
139    pub fn reshape(&mut self, shape: &[usize]) -> crate::Result<()> {
140        dispatch!(self, reshape, shape)
141    }
142
143    /// Attach pixel format metadata to this tensor.
144    ///
145    /// Validates that the tensor's shape is compatible with the format's
146    /// layout (packed, planar, or semi-planar).
147    ///
148    /// # Arguments
149    ///
150    /// * `format` - The pixel format to attach
151    ///
152    /// # Returns
153    ///
154    /// `Ok(())` on success, with the format stored as metadata on the tensor.
155    ///
156    /// # Errors
157    ///
158    /// Returns `Error::InvalidShape` if the tensor shape doesn't match
159    /// the expected layout for the given format.
160    pub fn set_format(&mut self, format: PixelFormat) -> crate::Result<()> {
161        dispatch!(self, set_format, format)
162    }
163
164    /// Attach pixel format metadata, consuming and returning self.
165    ///
166    /// Enables builder-style chaining.
167    ///
168    /// # Arguments
169    ///
170    /// * `format` - The pixel format to attach
171    ///
172    /// # Returns
173    ///
174    /// The tensor with format metadata attached.
175    ///
176    /// # Errors
177    ///
178    /// Returns `Error::InvalidShape` if the tensor shape doesn't match
179    /// the expected layout for the given format.
180    pub fn with_format(mut self, format: PixelFormat) -> crate::Result<Self> {
181        self.set_format(format)?;
182        Ok(self)
183    }
184
185    /// Row stride in bytes (`None` = tightly packed).
186    pub fn row_stride(&self) -> Option<usize> {
187        dispatch!(self, row_stride)
188    }
189
190    /// Effective row stride: stored stride or computed from format and width.
191    pub fn effective_row_stride(&self) -> Option<usize> {
192        dispatch!(self, effective_row_stride)
193    }
194
195    /// Set the row stride in bytes for externally allocated buffers with
196    /// row padding.
197    ///
198    /// Must be called before the tensor is first used for rendering. The
199    /// format must be set before calling this method.
200    pub fn set_row_stride(&mut self, stride: usize) -> crate::Result<()> {
201        dispatch!(self, set_row_stride, stride)
202    }
203
204    /// Builder-style: set row stride, consuming and returning self.
205    pub fn with_row_stride(mut self, stride: usize) -> crate::Result<Self> {
206        self.set_row_stride(stride)?;
207        Ok(self)
208    }
209
210    /// Byte offset within the DMA-BUF where image data starts (`None` = 0).
211    pub fn plane_offset(&self) -> Option<usize> {
212        dispatch!(self, plane_offset)
213    }
214
215    /// Set the byte offset within the DMA-BUF where image data starts.
216    pub fn set_plane_offset(&mut self, offset: usize) {
217        dispatch!(self, set_plane_offset, offset)
218    }
219
220    /// Builder-style: set plane offset, consuming and returning self.
221    pub fn with_plane_offset(mut self, offset: usize) -> Self {
222        self.set_plane_offset(offset);
223        self
224    }
225
226    /// Clone the file descriptor associated with this tensor.
227    #[cfg(unix)]
228    pub fn clone_fd(&self) -> crate::Result<std::os::fd::OwnedFd> {
229        dispatch!(self, clone_fd)
230    }
231
232    /// Clone the DMA-BUF file descriptor backing this tensor (Linux only).
233    ///
234    /// # Returns
235    ///
236    /// An owned duplicate of the DMA-BUF file descriptor.
237    ///
238    /// # Errors
239    ///
240    /// * `Error::NotImplemented` if the tensor is not DMA-backed (Mem/Shm/Pbo)
241    /// * `Error::IoError` if the fd clone syscall fails (e.g., fd limit reached)
242    #[cfg(target_os = "linux")]
243    pub fn dmabuf_clone(&self) -> crate::Result<std::os::fd::OwnedFd> {
244        if self.memory() != TensorMemory::Dma {
245            return Err(crate::Error::NotImplemented(format!(
246                "dmabuf_clone requires DMA-backed tensor, got {:?}",
247                self.memory()
248            )));
249        }
250        self.clone_fd()
251    }
252
253    /// Borrow the DMA-BUF file descriptor backing this tensor (Linux only).
254    ///
255    /// # Returns
256    ///
257    /// A borrowed reference to the DMA-BUF file descriptor, tied to `self`'s
258    /// lifetime.
259    ///
260    /// # Errors
261    ///
262    /// * `Error::NotImplemented` if the tensor is not DMA-backed
263    #[cfg(target_os = "linux")]
264    pub fn dmabuf(&self) -> crate::Result<std::os::fd::BorrowedFd<'_>> {
265        dispatch!(self, dmabuf)
266    }
267
268    /// Return `true` if this tensor uses separate plane allocations.
269    pub fn is_multiplane(&self) -> bool {
270        dispatch!(self, is_multiplane)
271    }
272
273    // --- Downcasting ---
274
275    downcast_methods!(U8, u8, as_u8, as_u8_mut, into_u8);
276    downcast_methods!(I8, i8, as_i8, as_i8_mut, into_i8);
277    downcast_methods!(U16, u16, as_u16, as_u16_mut, into_u16);
278    downcast_methods!(I16, i16, as_i16, as_i16_mut, into_i16);
279    downcast_methods!(U32, u32, as_u32, as_u32_mut, into_u32);
280    downcast_methods!(I32, i32, as_i32, as_i32_mut, into_i32);
281    downcast_methods!(U64, u64, as_u64, as_u64_mut, into_u64);
282    downcast_methods!(I64, i64, as_i64, as_i64_mut, into_i64);
283    downcast_methods!(F16, f16, as_f16, as_f16_mut, into_f16);
284    downcast_methods!(F32, f32, as_f32, as_f32_mut, into_f32);
285    downcast_methods!(F64, f64, as_f64, as_f64_mut, into_f64);
286
287    /// Create a type-erased tensor with the given shape and element type.
288    pub fn new(
289        shape: &[usize],
290        dtype: DType,
291        memory: Option<TensorMemory>,
292        name: Option<&str>,
293    ) -> crate::Result<Self> {
294        match dtype {
295            DType::U8 => Tensor::<u8>::new(shape, memory, name).map(Self::U8),
296            DType::I8 => Tensor::<i8>::new(shape, memory, name).map(Self::I8),
297            DType::U16 => Tensor::<u16>::new(shape, memory, name).map(Self::U16),
298            DType::I16 => Tensor::<i16>::new(shape, memory, name).map(Self::I16),
299            DType::U32 => Tensor::<u32>::new(shape, memory, name).map(Self::U32),
300            DType::I32 => Tensor::<i32>::new(shape, memory, name).map(Self::I32),
301            DType::U64 => Tensor::<u64>::new(shape, memory, name).map(Self::U64),
302            DType::I64 => Tensor::<i64>::new(shape, memory, name).map(Self::I64),
303            DType::F16 => Tensor::<f16>::new(shape, memory, name).map(Self::F16),
304            DType::F32 => Tensor::<f32>::new(shape, memory, name).map(Self::F32),
305            DType::F64 => Tensor::<f64>::new(shape, memory, name).map(Self::F64),
306        }
307    }
308
309    /// Create a type-erased tensor from a file descriptor.
310    #[cfg(unix)]
311    pub fn from_fd(
312        fd: std::os::fd::OwnedFd,
313        shape: &[usize],
314        dtype: DType,
315        name: Option<&str>,
316    ) -> crate::Result<Self> {
317        match dtype {
318            DType::U8 => Tensor::<u8>::from_fd(fd, shape, name).map(Self::U8),
319            DType::I8 => Tensor::<i8>::from_fd(fd, shape, name).map(Self::I8),
320            DType::U16 => Tensor::<u16>::from_fd(fd, shape, name).map(Self::U16),
321            DType::I16 => Tensor::<i16>::from_fd(fd, shape, name).map(Self::I16),
322            DType::U32 => Tensor::<u32>::from_fd(fd, shape, name).map(Self::U32),
323            DType::I32 => Tensor::<i32>::from_fd(fd, shape, name).map(Self::I32),
324            DType::U64 => Tensor::<u64>::from_fd(fd, shape, name).map(Self::U64),
325            DType::I64 => Tensor::<i64>::from_fd(fd, shape, name).map(Self::I64),
326            DType::F16 => Tensor::<f16>::from_fd(fd, shape, name).map(Self::F16),
327            DType::F32 => Tensor::<f32>::from_fd(fd, shape, name).map(Self::F32),
328            DType::F64 => Tensor::<f64>::from_fd(fd, shape, name).map(Self::F64),
329        }
330    }
331
332    /// Create a type-erased image tensor.
333    ///
334    /// # Arguments
335    ///
336    /// * `width` - Image width in pixels
337    /// * `height` - Image height in pixels
338    /// * `format` - Pixel format
339    /// * `dtype` - Element type discriminant
340    /// * `memory` - Optional memory backend (None selects the best available)
341    ///
342    /// # Returns
343    ///
344    /// A new `TensorDyn` wrapping an image tensor of the requested element type.
345    ///
346    /// # Errors
347    ///
348    /// Returns an error if the underlying `Tensor::image` call fails.
349    pub fn image(
350        width: usize,
351        height: usize,
352        format: PixelFormat,
353        dtype: DType,
354        memory: Option<TensorMemory>,
355    ) -> crate::Result<Self> {
356        match dtype {
357            DType::U8 => Tensor::<u8>::image(width, height, format, memory).map(Self::U8),
358            DType::I8 => Tensor::<i8>::image(width, height, format, memory).map(Self::I8),
359            DType::U16 => Tensor::<u16>::image(width, height, format, memory).map(Self::U16),
360            DType::I16 => Tensor::<i16>::image(width, height, format, memory).map(Self::I16),
361            DType::U32 => Tensor::<u32>::image(width, height, format, memory).map(Self::U32),
362            DType::I32 => Tensor::<i32>::image(width, height, format, memory).map(Self::I32),
363            DType::U64 => Tensor::<u64>::image(width, height, format, memory).map(Self::U64),
364            DType::I64 => Tensor::<i64>::image(width, height, format, memory).map(Self::I64),
365            DType::F16 => Tensor::<f16>::image(width, height, format, memory).map(Self::F16),
366            DType::F32 => Tensor::<f32>::image(width, height, format, memory).map(Self::F32),
367            DType::F64 => Tensor::<f64>::image(width, height, format, memory).map(Self::F64),
368        }
369    }
370
371    /// Create a DMA-backed image tensor with an explicit row stride that
372    /// may exceed the natural `width * channels * sizeof(T)` pitch.
373    ///
374    /// See [`Tensor::image_with_stride`] for the detailed contract and
375    /// constraints. The TensorDyn wrapper dispatches to the appropriate
376    /// monomorphised `Tensor<T>` based on `dtype`.
377    ///
378    /// # Example
379    ///
380    /// ```no_run
381    /// use edgefirst_tensor::{TensorDyn, PixelFormat, DType, TensorMemory};
382    /// # fn main() -> edgefirst_tensor::Result<()> {
383    /// // Allocate a 3004×1688 RGBA8 canvas with 64-byte pitch alignment
384    /// // (12032 bytes per row instead of the natural 12016).
385    /// let img = TensorDyn::image_with_stride(
386    ///     3004, 1688,
387    ///     PixelFormat::Rgba, DType::U8,
388    ///     12032,
389    ///     Some(TensorMemory::Dma),
390    /// )?;
391    /// assert_eq!(img.width(), Some(3004));       // logical, unchanged
392    /// assert_eq!(img.effective_row_stride(), Some(12032)); // padded
393    /// # Ok(())
394    /// # }
395    /// ```
396    pub fn image_with_stride(
397        width: usize,
398        height: usize,
399        format: PixelFormat,
400        dtype: DType,
401        row_stride_bytes: usize,
402        memory: Option<TensorMemory>,
403    ) -> crate::Result<Self> {
404        match dtype {
405            DType::U8 => {
406                Tensor::<u8>::image_with_stride(width, height, format, row_stride_bytes, memory)
407                    .map(Self::U8)
408            }
409            DType::I8 => {
410                Tensor::<i8>::image_with_stride(width, height, format, row_stride_bytes, memory)
411                    .map(Self::I8)
412            }
413            DType::U16 => {
414                Tensor::<u16>::image_with_stride(width, height, format, row_stride_bytes, memory)
415                    .map(Self::U16)
416            }
417            DType::I16 => {
418                Tensor::<i16>::image_with_stride(width, height, format, row_stride_bytes, memory)
419                    .map(Self::I16)
420            }
421            DType::U32 => {
422                Tensor::<u32>::image_with_stride(width, height, format, row_stride_bytes, memory)
423                    .map(Self::U32)
424            }
425            DType::I32 => {
426                Tensor::<i32>::image_with_stride(width, height, format, row_stride_bytes, memory)
427                    .map(Self::I32)
428            }
429            DType::U64 => {
430                Tensor::<u64>::image_with_stride(width, height, format, row_stride_bytes, memory)
431                    .map(Self::U64)
432            }
433            DType::I64 => {
434                Tensor::<i64>::image_with_stride(width, height, format, row_stride_bytes, memory)
435                    .map(Self::I64)
436            }
437            DType::F16 => {
438                Tensor::<f16>::image_with_stride(width, height, format, row_stride_bytes, memory)
439                    .map(Self::F16)
440            }
441            DType::F32 => {
442                Tensor::<f32>::image_with_stride(width, height, format, row_stride_bytes, memory)
443                    .map(Self::F32)
444            }
445            DType::F64 => {
446                Tensor::<f64>::image_with_stride(width, height, format, row_stride_bytes, memory)
447                    .map(Self::F64)
448            }
449        }
450    }
451}
452
453// --- From impls ---
454
455impl From<Tensor<u8>> for TensorDyn {
456    fn from(t: Tensor<u8>) -> Self {
457        Self::U8(t)
458    }
459}
460
461impl From<Tensor<i8>> for TensorDyn {
462    fn from(t: Tensor<i8>) -> Self {
463        Self::I8(t)
464    }
465}
466
467impl From<Tensor<u16>> for TensorDyn {
468    fn from(t: Tensor<u16>) -> Self {
469        Self::U16(t)
470    }
471}
472
473impl From<Tensor<i16>> for TensorDyn {
474    fn from(t: Tensor<i16>) -> Self {
475        Self::I16(t)
476    }
477}
478
479impl From<Tensor<u32>> for TensorDyn {
480    fn from(t: Tensor<u32>) -> Self {
481        Self::U32(t)
482    }
483}
484
485impl From<Tensor<i32>> for TensorDyn {
486    fn from(t: Tensor<i32>) -> Self {
487        Self::I32(t)
488    }
489}
490
491impl From<Tensor<u64>> for TensorDyn {
492    fn from(t: Tensor<u64>) -> Self {
493        Self::U64(t)
494    }
495}
496
497impl From<Tensor<i64>> for TensorDyn {
498    fn from(t: Tensor<i64>) -> Self {
499        Self::I64(t)
500    }
501}
502
503impl From<Tensor<f16>> for TensorDyn {
504    fn from(t: Tensor<f16>) -> Self {
505        Self::F16(t)
506    }
507}
508
509impl From<Tensor<f32>> for TensorDyn {
510    fn from(t: Tensor<f32>) -> Self {
511        Self::F32(t)
512    }
513}
514
515impl From<Tensor<f64>> for TensorDyn {
516    fn from(t: Tensor<f64>) -> Self {
517        Self::F64(t)
518    }
519}
520
521impl fmt::Debug for TensorDyn {
522    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
523        dispatch!(self, fmt, f)
524    }
525}
526
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `u8` `TensorDyn` of `shape` with no memory preference and no name.
    #[track_caller]
    fn u8_tensor(shape: &[usize]) -> TensorDyn {
        TensorDyn::new(shape, DType::U8, None, None).unwrap()
    }

    /// Build a `u8` image `TensorDyn` with no memory preference.
    #[track_caller]
    fn u8_image(width: usize, height: usize, format: PixelFormat) -> TensorDyn {
        TensorDyn::image(width, height, format, DType::U8, None).unwrap()
    }

    /// Build a typed 100x100 RGBA `Tensor<u8>` in `TensorMemory::Mem`.
    #[track_caller]
    fn mem_rgba_100() -> Tensor<u8> {
        Tensor::<u8>::image(100, 100, PixelFormat::Rgba, Some(TensorMemory::Mem)).unwrap()
    }

    #[test]
    fn from_typed_tensor() {
        let erased = TensorDyn::from(Tensor::<u8>::new(&[10], None, None).unwrap());
        assert_eq!(erased.dtype(), DType::U8);
        assert_eq!(erased.shape(), &[10]);
    }

    #[test]
    fn downcast_ref() {
        let erased = TensorDyn::from(Tensor::<u8>::new(&[10], None, None).unwrap());
        assert!(erased.as_u8().is_some());
        assert!(erased.as_i8().is_none());
    }

    #[test]
    fn downcast_into() {
        let erased = TensorDyn::from(Tensor::<u8>::new(&[10], None, None).unwrap());
        let typed = erased.into_u8().unwrap();
        assert_eq!(typed.shape(), &[10]);
    }

    #[test]
    fn image_accessors() {
        let typed = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        let erased = TensorDyn::from(typed);
        assert_eq!(erased.format(), Some(PixelFormat::Rgba));
        assert_eq!(erased.width(), Some(640));
        assert_eq!(erased.height(), Some(480));
        assert!(!erased.is_multiplane());
    }

    #[test]
    fn image_constructor() {
        let img = TensorDyn::image(640, 480, PixelFormat::Rgb, DType::U8, None).unwrap();
        assert_eq!(img.dtype(), DType::U8);
        assert_eq!(img.format(), Some(PixelFormat::Rgb));
        assert_eq!(img.width(), Some(640));
    }

    #[test]
    fn image_constructor_i8() {
        let img = TensorDyn::image(640, 480, PixelFormat::Rgb, DType::I8, None).unwrap();
        assert_eq!(img.dtype(), DType::I8);
        assert_eq!(img.format(), Some(PixelFormat::Rgb));
    }

    #[test]
    fn set_format_packed() {
        let mut tensor = u8_tensor(&[480, 640, 3]);
        assert_eq!(tensor.format(), None);
        tensor.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(tensor.format(), Some(PixelFormat::Rgb));
        assert_eq!(tensor.width(), Some(640));
        assert_eq!(tensor.height(), Some(480));
    }

    #[test]
    fn set_format_planar() {
        let mut tensor = u8_tensor(&[3, 480, 640]);
        tensor.set_format(PixelFormat::PlanarRgb).unwrap();
        assert_eq!(tensor.format(), Some(PixelFormat::PlanarRgb));
        assert_eq!(tensor.width(), Some(640));
        assert_eq!(tensor.height(), Some(480));
    }

    #[test]
    fn set_format_rejects_wrong_shape() {
        let mut tensor = u8_tensor(&[480, 640, 4]);
        assert!(tensor.set_format(PixelFormat::Rgb).is_err());
    }

    #[test]
    fn with_format_builder() {
        let tensor = u8_tensor(&[480, 640, 4])
            .with_format(PixelFormat::Rgba)
            .unwrap();
        assert_eq!(tensor.format(), Some(PixelFormat::Rgba));
        assert_eq!(tensor.width(), Some(640));
        assert_eq!(tensor.height(), Some(480));
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn dmabuf_clone_mem_tensor_fails() {
        let tensor =
            TensorDyn::new(&[480, 640, 3], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        assert_eq!(tensor.memory(), TensorMemory::Mem);
        assert!(tensor.dmabuf_clone().is_err());
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn dmabuf_mem_tensor_fails() {
        let tensor =
            TensorDyn::new(&[480, 640, 3], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        assert!(tensor.dmabuf().is_err());
    }

    #[test]
    fn set_format_semi_planar_nv12() {
        // NV12 stores luma plus half-height chroma: 480 * 3/2 = 720 rows.
        let mut tensor =
            TensorDyn::new(&[720, 640], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        tensor.set_format(PixelFormat::Nv12).unwrap();
        assert_eq!(tensor.format(), Some(PixelFormat::Nv12));
        assert_eq!(tensor.width(), Some(640));
        assert_eq!(tensor.height(), Some(480));
    }

    #[test]
    fn set_format_semi_planar_nv16() {
        // NV16 stores luma plus full-height chroma: 480 * 2 = 960 rows.
        let mut tensor =
            TensorDyn::new(&[960, 640], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        tensor.set_format(PixelFormat::Nv16).unwrap();
        assert_eq!(tensor.format(), Some(PixelFormat::Nv16));
        assert_eq!(tensor.width(), Some(640));
        assert_eq!(tensor.height(), Some(480));
    }

    #[test]
    fn with_format_rejects_wrong_shape() {
        let outcome = u8_tensor(&[480, 640, 4]).with_format(PixelFormat::Rgb);
        assert!(outcome.is_err());
    }

    #[test]
    fn set_format_preserved_after_rejection() {
        let mut tensor = u8_tensor(&[480, 640, 3]);
        tensor.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(tensor.format(), Some(PixelFormat::Rgb));

        // A 3-channel tensor cannot take the 4-channel Rgba format.
        assert!(tensor.set_format(PixelFormat::Rgba).is_err());

        // The failed attempt must leave the earlier format in place.
        assert_eq!(tensor.format(), Some(PixelFormat::Rgb));
    }

    #[test]
    fn set_format_idempotent() {
        let mut tensor = u8_tensor(&[480, 640, 3]);
        tensor.set_format(PixelFormat::Rgb).unwrap();
        tensor.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(tensor.format(), Some(PixelFormat::Rgb));
        assert_eq!(tensor.width(), Some(640));
        assert_eq!(tensor.height(), Some(480));
    }

    // --- Row stride ---

    #[test]
    fn set_row_stride_valid() {
        // 100 px of RGBA needs at least 400 bytes per row; 512 is padded.
        let mut img = u8_image(100, 100, PixelFormat::Rgba);
        img.set_row_stride(512).unwrap();
        assert_eq!(img.row_stride(), Some(512));
        assert_eq!(img.effective_row_stride(), Some(512));
    }

    #[test]
    fn set_row_stride_equals_min() {
        // 100 px of RGB needs exactly 300 bytes per row.
        let mut img = u8_image(100, 100, PixelFormat::Rgb);
        img.set_row_stride(300).unwrap();
        assert_eq!(img.row_stride(), Some(300));
    }

    #[test]
    fn set_row_stride_too_small() {
        // 300 bytes is below the 400-byte minimum for 100 px of RGBA.
        let mut img = u8_image(100, 100, PixelFormat::Rgba);
        assert!(img.set_row_stride(300).is_err());
        assert_eq!(img.row_stride(), None);
    }

    #[test]
    fn set_row_stride_zero() {
        let mut img = u8_image(100, 100, PixelFormat::Rgb);
        assert!(img.set_row_stride(0).is_err());
    }

    #[test]
    fn set_row_stride_requires_format() {
        let mut tensor = u8_tensor(&[480, 640, 3]);
        assert!(tensor.set_row_stride(2048).is_err());
    }

    #[test]
    fn effective_row_stride_without_stride() {
        let img = u8_image(100, 100, PixelFormat::Rgb);
        assert_eq!(img.row_stride(), None);
        // Falls back to width * channels = 100 * 3.
        assert_eq!(img.effective_row_stride(), Some(300));
    }

    #[test]
    fn effective_row_stride_no_format() {
        let tensor = u8_tensor(&[480, 640, 3]);
        assert_eq!(tensor.effective_row_stride(), None);
    }

    #[test]
    fn with_row_stride_builder() {
        let img = u8_image(100, 100, PixelFormat::Rgba)
            .with_row_stride(512)
            .unwrap();
        assert_eq!(img.row_stride(), Some(512));
        assert_eq!(img.effective_row_stride(), Some(512));
    }

    #[test]
    fn with_row_stride_rejects_small() {
        let outcome = u8_image(100, 100, PixelFormat::Rgba).with_row_stride(200);
        assert!(outcome.is_err());
    }

    #[test]
    fn set_format_clears_row_stride() {
        let mut tensor = u8_tensor(&[480, 640, 3]);
        tensor.set_format(PixelFormat::Rgb).unwrap();
        tensor.set_row_stride(2048).unwrap();
        assert_eq!(tensor.row_stride(), Some(2048));

        // A 4-channel format on a 3-channel shape is rejected; stride survives.
        let _ = tensor.set_format(PixelFormat::Bgra);
        assert_eq!(tensor.row_stride(), Some(2048));

        // Setting the identical format again also keeps the stride.
        tensor.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(tensor.row_stride(), Some(2048));

        // Reshaping drops both the format and the stride.
        tensor.reshape(&[480 * 640 * 3]).unwrap();
        assert_eq!(tensor.row_stride(), None);
        assert_eq!(tensor.format(), None);
    }

    #[test]
    fn set_format_different_compatible_clears_stride() {
        // Rgba and Bgra are both 4-channel packed formats, so switching
        // between them is accepted and has to reset the stored stride.
        let mut tensor = u8_tensor(&[480, 640, 4]);
        tensor.set_format(PixelFormat::Rgba).unwrap();
        tensor.set_row_stride(4096).unwrap();
        assert_eq!(tensor.row_stride(), Some(4096));

        // Accepted change to a different format: stride is gone afterwards.
        tensor.set_format(PixelFormat::Bgra).unwrap();
        assert_eq!(tensor.format(), Some(PixelFormat::Bgra));
        assert_eq!(tensor.row_stride(), None);
    }

    #[test]
    fn set_format_same_preserves_stride() {
        let mut img = u8_image(100, 100, PixelFormat::Rgb);
        img.set_row_stride(512).unwrap();
        // Applying the format that is already set must keep the stride.
        img.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(img.row_stride(), Some(512));
    }

    #[test]
    fn effective_row_stride_planar() {
        let img = u8_image(640, 480, PixelFormat::PlanarRgb);
        // Planar layout: one row is just `width` bytes.
        assert_eq!(img.effective_row_stride(), Some(640));
    }

    #[test]
    fn effective_row_stride_nv12() {
        let img = u8_image(640, 480, PixelFormat::Nv12);
        // Semi-planar layout: one row is just `width` bytes.
        assert_eq!(img.effective_row_stride(), Some(640));
    }

    #[test]
    fn map_rejects_strided_tensor() {
        let mut typed = mem_rgba_100();
        // Mapping succeeds while no stride has been set.
        assert!(typed.map().is_ok());
        // Once a stride is stored, mapping has to be refused.
        typed.set_row_stride(512).unwrap();
        assert!(typed.map().is_err());
    }

    // --- Plane offset ---

    #[test]
    fn plane_offset_default_none() {
        let img = u8_image(100, 100, PixelFormat::Rgba);
        assert_eq!(img.plane_offset(), None);
    }

    #[test]
    fn set_plane_offset_basic() {
        let mut img = u8_image(100, 100, PixelFormat::Rgba);
        img.set_plane_offset(4096);
        assert_eq!(img.plane_offset(), Some(4096));
    }

    #[test]
    fn set_plane_offset_zero() {
        let mut img = u8_image(100, 100, PixelFormat::Rgb);
        img.set_plane_offset(0);
        assert_eq!(img.plane_offset(), Some(0));
    }

    #[test]
    fn set_plane_offset_no_format() {
        // The offset is independent of the format, so none is needed here.
        let mut tensor = u8_tensor(&[480, 640, 3]);
        tensor.set_plane_offset(4096);
        assert_eq!(tensor.plane_offset(), Some(4096));
    }

    #[test]
    fn with_plane_offset_builder() {
        let img = u8_image(100, 100, PixelFormat::Rgba).with_plane_offset(8192);
        assert_eq!(img.plane_offset(), Some(8192));
    }

    #[test]
    fn set_format_clears_plane_offset() {
        let mut tensor = u8_tensor(&[480, 640, 3]);
        tensor.set_format(PixelFormat::Rgb).unwrap();
        tensor.set_plane_offset(4096);
        assert_eq!(tensor.plane_offset(), Some(4096));

        // Setting the identical format again keeps the offset.
        tensor.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(tensor.plane_offset(), Some(4096));

        // Reshaping drops the offset together with the format.
        tensor.reshape(&[480 * 640 * 3]).unwrap();
        assert_eq!(tensor.plane_offset(), None);
        assert_eq!(tensor.format(), None);
    }

    #[test]
    fn map_rejects_offset_tensor() {
        let mut typed = mem_rgba_100();
        // Mapping succeeds while no offset has been set.
        assert!(typed.map().is_ok());
        // A non-zero offset makes mapping fail.
        typed.set_plane_offset(4096);
        assert!(typed.map().is_err());
    }

    #[test]
    fn map_accepts_zero_offset_tensor() {
        let mut typed = mem_rgba_100();
        typed.set_plane_offset(0);
        // An explicit offset of zero does not block CPU mapping.
        assert!(typed.map().is_ok());
    }

    #[test]
    fn from_planes_propagates_plane_offset() {
        let mut luma =
            Tensor::<u8>::new(&[480, 640], Some(TensorMemory::Mem), Some("luma")).unwrap();
        luma.set_plane_offset(4096);
        let chroma =
            Tensor::<u8>::new(&[240, 640], Some(TensorMemory::Mem), Some("chroma")).unwrap();
        let merged = Tensor::<u8>::from_planes(luma, chroma, PixelFormat::Nv12).unwrap();
        assert_eq!(merged.plane_offset(), Some(4096));
    }
}
902            Tensor::<u8>::new(&[240, 640], Some(TensorMemory::Mem), Some("chroma")).unwrap();
903        let combined = Tensor::<u8>::from_planes(luma, chroma, PixelFormat::Nv12).unwrap();
904        assert_eq!(combined.plane_offset(), Some(4096));
905    }
906}