Skip to main content

edgefirst_tensor/
tensor_dyn.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{DType, PixelFormat, Tensor, TensorMemory, TensorTrait};
5use half::f16;
6use std::fmt;
7
/// Type-erased tensor. Wraps a `Tensor<T>` with runtime element type.
///
/// Exactly one variant is active at a time; [`TensorDyn::dtype`] reports
/// which one as a [`DType`]. Typed tensors convert into this enum through the
/// `From<Tensor<T>>` impls, and the `as_*` / `as_*_mut` / `into_*` methods
/// recover the typed tensor when the element type matches.
///
/// The enum is `#[non_exhaustive]`: code outside this crate that matches on
/// it must include a wildcard arm so new element types can be added later.
#[non_exhaustive]
pub enum TensorDyn {
    /// Unsigned 8-bit integer tensor (`Tensor<u8>`).
    U8(Tensor<u8>),
    /// Signed 8-bit integer tensor (`Tensor<i8>`).
    I8(Tensor<i8>),
    /// Unsigned 16-bit integer tensor (`Tensor<u16>`).
    U16(Tensor<u16>),
    /// Signed 16-bit integer tensor (`Tensor<i16>`).
    I16(Tensor<i16>),
    /// Unsigned 32-bit integer tensor (`Tensor<u32>`).
    U32(Tensor<u32>),
    /// Signed 32-bit integer tensor (`Tensor<i32>`).
    I32(Tensor<i32>),
    /// Unsigned 64-bit integer tensor (`Tensor<u64>`).
    U64(Tensor<u64>),
    /// Signed 64-bit integer tensor (`Tensor<i64>`).
    I64(Tensor<i64>),
    /// 16-bit floating-point tensor (`Tensor<half::f16>`).
    F16(Tensor<f16>),
    /// 32-bit floating-point tensor (`Tensor<f32>`).
    F32(Tensor<f32>),
    /// 64-bit floating-point tensor (`Tensor<f64>`).
    F64(Tensor<f64>),
}
34
/// Forward a method call to the `Tensor<T>` held by whichever `TensorDyn`
/// variant is active.
///
/// Usage: `dispatch!(target, method, arg1, arg2, ...)`. `target` is used
/// directly as the `match` scrutinee, so passing `self` (a `&TensorDyn` or
/// `&mut TensorDyn`) binds the inner tensor by shared or mutable reference
/// respectively. Every arm calls the same method with the same argument
/// tokens, so the method must exist, with one common return type, on every
/// `Tensor<T>` the enum can hold. Only the active arm runs, so each argument
/// expression is evaluated once.
macro_rules! dispatch {
    // Internal rule: emit one match arm per listed variant. It comes first so
    // the public rule's `expr` matcher is never asked to parse `@arms`.
    (@arms $target:expr, $func:ident, $params:tt; $($variant:ident)*) => {
        match $target {
            $(TensorDyn::$variant(inner) => inner.$func $params,)*
        }
    };
    ($target:expr, $func:ident $(, $param:expr)*) => {
        dispatch!(
            @arms $target, $func, ($($param),*);
            U8 I8 U16 I16 U32 I32 U64 I64 F16 F32 F64
        )
    };
}
53
/// Expand to the three downcast accessors for a single `TensorDyn` variant.
///
/// Arguments, in order: the variant name, its element type, then the names
/// for the `&self`, `&mut self` and by-value (`self`) accessors.
macro_rules! downcast_methods {
    ($variant:ident, $ty:ty, $as_name:ident, $as_mut_name:ident, $into_name:ident) => {
        /// Borrow the inner typed tensor, or `None` if this value holds a
        /// different element type.
        pub fn $as_name(&self) -> Option<&Tensor<$ty>> {
            if let Self::$variant(inner) = self {
                Some(inner)
            } else {
                None
            }
        }

        /// Mutably borrow the inner typed tensor, or `None` if this value
        /// holds a different element type.
        pub fn $as_mut_name(&mut self) -> Option<&mut Tensor<$ty>> {
            if let Self::$variant(inner) = self {
                Some(inner)
            } else {
                None
            }
        }

        /// Take ownership of the inner typed tensor. On an element-type
        /// mismatch the untouched value is handed back in `Err`.
        // Returning the whole enum in `Err` is deliberate, hence the lint allow.
        #[allow(clippy::result_large_err)]
        pub fn $into_name(self) -> Result<Tensor<$ty>, Self> {
            if let Self::$variant(inner) = self {
                Ok(inner)
            } else {
                Err(self)
            }
        }
    };
}
84
85impl TensorDyn {
86    /// Return the runtime element type discriminant.
87    pub fn dtype(&self) -> DType {
88        match self {
89            Self::U8(_) => DType::U8,
90            Self::I8(_) => DType::I8,
91            Self::U16(_) => DType::U16,
92            Self::I16(_) => DType::I16,
93            Self::U32(_) => DType::U32,
94            Self::I32(_) => DType::I32,
95            Self::U64(_) => DType::U64,
96            Self::I64(_) => DType::I64,
97            Self::F16(_) => DType::F16,
98            Self::F32(_) => DType::F32,
99            Self::F64(_) => DType::F64,
100        }
101    }
102
103    /// Return the tensor shape.
104    pub fn shape(&self) -> &[usize] {
105        dispatch!(self, shape)
106    }
107
108    /// Return the tensor name.
109    pub fn name(&self) -> String {
110        dispatch!(self, name)
111    }
112
113    /// Return the pixel format (None if not an image tensor).
114    pub fn format(&self) -> Option<PixelFormat> {
115        dispatch!(self, format)
116    }
117
118    /// Return the image width (None if not an image tensor).
119    pub fn width(&self) -> Option<usize> {
120        dispatch!(self, width)
121    }
122
123    /// Return the image height (None if not an image tensor).
124    pub fn height(&self) -> Option<usize> {
125        dispatch!(self, height)
126    }
127
128    /// Return the total size of this tensor in bytes.
129    pub fn size(&self) -> usize {
130        dispatch!(self, size)
131    }
132
133    /// Return the memory allocation type.
134    pub fn memory(&self) -> TensorMemory {
135        dispatch!(self, memory)
136    }
137
138    /// Reshape this tensor. Total element count must remain the same.
139    pub fn reshape(&mut self, shape: &[usize]) -> crate::Result<()> {
140        dispatch!(self, reshape, shape)
141    }
142
143    /// Attach pixel format metadata to this tensor.
144    ///
145    /// Validates that the tensor's shape is compatible with the format's
146    /// layout (packed, planar, or semi-planar).
147    ///
148    /// # Arguments
149    ///
150    /// * `format` - The pixel format to attach
151    ///
152    /// # Returns
153    ///
154    /// `Ok(())` on success, with the format stored as metadata on the tensor.
155    ///
156    /// # Errors
157    ///
158    /// Returns `Error::InvalidShape` if the tensor shape doesn't match
159    /// the expected layout for the given format.
160    pub fn set_format(&mut self, format: PixelFormat) -> crate::Result<()> {
161        dispatch!(self, set_format, format)
162    }
163
164    /// Attach pixel format metadata, consuming and returning self.
165    ///
166    /// Enables builder-style chaining.
167    ///
168    /// # Arguments
169    ///
170    /// * `format` - The pixel format to attach
171    ///
172    /// # Returns
173    ///
174    /// The tensor with format metadata attached.
175    ///
176    /// # Errors
177    ///
178    /// Returns `Error::InvalidShape` if the tensor shape doesn't match
179    /// the expected layout for the given format.
180    pub fn with_format(mut self, format: PixelFormat) -> crate::Result<Self> {
181        self.set_format(format)?;
182        Ok(self)
183    }
184
185    /// Row stride in bytes (`None` = tightly packed).
186    pub fn row_stride(&self) -> Option<usize> {
187        dispatch!(self, row_stride)
188    }
189
190    /// Effective row stride: stored stride or computed from format and width.
191    pub fn effective_row_stride(&self) -> Option<usize> {
192        dispatch!(self, effective_row_stride)
193    }
194
195    /// Set the row stride in bytes for externally allocated buffers with
196    /// row padding.
197    ///
198    /// Must be called before the tensor is first used for rendering. The
199    /// format must be set before calling this method.
200    pub fn set_row_stride(&mut self, stride: usize) -> crate::Result<()> {
201        dispatch!(self, set_row_stride, stride)
202    }
203
204    /// Builder-style: set row stride, consuming and returning self.
205    pub fn with_row_stride(mut self, stride: usize) -> crate::Result<Self> {
206        self.set_row_stride(stride)?;
207        Ok(self)
208    }
209
210    /// Byte offset within the DMA-BUF where image data starts (`None` = 0).
211    pub fn plane_offset(&self) -> Option<usize> {
212        dispatch!(self, plane_offset)
213    }
214
215    /// Set the byte offset within the DMA-BUF where image data starts.
216    pub fn set_plane_offset(&mut self, offset: usize) {
217        dispatch!(self, set_plane_offset, offset)
218    }
219
220    /// Builder-style: set plane offset, consuming and returning self.
221    pub fn with_plane_offset(mut self, offset: usize) -> Self {
222        self.set_plane_offset(offset);
223        self
224    }
225
226    /// Quantization metadata. Returns `None` for float variants (F16, F32,
227    /// F64) — quantization does not apply to floating-point tensors.
228    /// Otherwise delegates to the typed `Tensor<T>::quantization()` accessor.
229    pub fn quantization(&self) -> Option<&crate::Quantization> {
230        match self {
231            Self::U8(t) => t.quantization(),
232            Self::I8(t) => t.quantization(),
233            Self::U16(t) => t.quantization(),
234            Self::I16(t) => t.quantization(),
235            Self::U32(t) => t.quantization(),
236            Self::I32(t) => t.quantization(),
237            Self::U64(t) => t.quantization(),
238            Self::I64(t) => t.quantization(),
239            Self::F16(_) | Self::F32(_) | Self::F64(_) => None,
240        }
241    }
242
243    /// Attach quantization metadata. Fails on float variants with
244    /// [`Error::QuantizationInvalid`]; delegates to the typed setter for
245    /// integer variants.
246    pub fn set_quantization(&mut self, q: crate::Quantization) -> crate::Result<()> {
247        match self {
248            Self::U8(t) => t.set_quantization(q),
249            Self::I8(t) => t.set_quantization(q),
250            Self::U16(t) => t.set_quantization(q),
251            Self::I16(t) => t.set_quantization(q),
252            Self::U32(t) => t.set_quantization(q),
253            Self::I32(t) => t.set_quantization(q),
254            Self::U64(t) => t.set_quantization(q),
255            Self::I64(t) => t.set_quantization(q),
256            Self::F16(_) | Self::F32(_) | Self::F64(_) => Err(crate::Error::QuantizationInvalid {
257                field: "dtype_is_integer",
258                expected: "integer tensor dtype (u8/i8/u16/i16/u32/i32/u64/i64)".to_string(),
259                got: format!("{:?}", self.dtype()),
260            }),
261        }
262    }
263
264    /// Builder-style variant of [`Self::set_quantization`]. Consumes self
265    /// and returns it with quantization applied (or the original error).
266    pub fn with_quantization(mut self, q: crate::Quantization) -> crate::Result<Self> {
267        self.set_quantization(q)?;
268        Ok(self)
269    }
270
271    /// Clear any quantization metadata. No-op on float variants.
272    pub fn clear_quantization(&mut self) {
273        match self {
274            Self::U8(t) => t.clear_quantization(),
275            Self::I8(t) => t.clear_quantization(),
276            Self::U16(t) => t.clear_quantization(),
277            Self::I16(t) => t.clear_quantization(),
278            Self::U32(t) => t.clear_quantization(),
279            Self::I32(t) => t.clear_quantization(),
280            Self::U64(t) => t.clear_quantization(),
281            Self::I64(t) => t.clear_quantization(),
282            Self::F16(_) | Self::F32(_) | Self::F64(_) => {}
283        }
284    }
285
286    /// Clone the file descriptor associated with this tensor.
287    #[cfg(unix)]
288    pub fn clone_fd(&self) -> crate::Result<std::os::fd::OwnedFd> {
289        dispatch!(self, clone_fd)
290    }
291
292    /// Clone the DMA-BUF file descriptor backing this tensor (Linux only).
293    ///
294    /// # Returns
295    ///
296    /// An owned duplicate of the DMA-BUF file descriptor.
297    ///
298    /// # Errors
299    ///
300    /// * `Error::NotImplemented` if the tensor is not DMA-backed (Mem/Shm/Pbo)
301    /// * `Error::IoError` if the fd clone syscall fails (e.g., fd limit reached)
302    #[cfg(target_os = "linux")]
303    pub fn dmabuf_clone(&self) -> crate::Result<std::os::fd::OwnedFd> {
304        if self.memory() != TensorMemory::Dma {
305            return Err(crate::Error::NotImplemented(format!(
306                "dmabuf_clone requires DMA-backed tensor, got {:?}",
307                self.memory()
308            )));
309        }
310        self.clone_fd()
311    }
312
313    /// Borrow the DMA-BUF file descriptor backing this tensor (Linux only).
314    ///
315    /// # Returns
316    ///
317    /// A borrowed reference to the DMA-BUF file descriptor, tied to `self`'s
318    /// lifetime.
319    ///
320    /// # Errors
321    ///
322    /// * `Error::NotImplemented` if the tensor is not DMA-backed
323    #[cfg(target_os = "linux")]
324    pub fn dmabuf(&self) -> crate::Result<std::os::fd::BorrowedFd<'_>> {
325        dispatch!(self, dmabuf)
326    }
327
328    /// Return `true` if this tensor uses separate plane allocations.
329    pub fn is_multiplane(&self) -> bool {
330        dispatch!(self, is_multiplane)
331    }
332
333    /// Return the [`BufferIdentity`](crate::BufferIdentity) of the underlying
334    /// allocation.
335    ///
336    /// Two `TensorDyn` values share a [`BufferIdentity::id`] iff they were
337    /// produced by cloning the same allocation (e.g. through
338    /// [`DmaTensor::try_clone`](crate::dma::DmaTensor::try_clone)). Separate
339    /// imports of the same physical buffer (e.g. two `from_fd` calls on the
340    /// same dmabuf fd) have **distinct** identities — use
341    /// [`aliases`](Self::aliases) if you need to detect that case.
342    pub fn buffer_identity(&self) -> &crate::BufferIdentity {
343        dispatch!(self, buffer_identity)
344    }
345
346    /// Return `true` if `self` and `other` reference the same underlying
347    /// buffer.
348    ///
349    /// This is the correct check for APIs that require distinct input and
350    /// output tensors (e.g. `ImageProcessor::draw_decoded_masks`, where
351    /// aliasing `dst` and `background` would cause the GL backend to read
352    /// and write the same texture — undefined behaviour on most drivers).
353    ///
354    /// Matching is conservative:
355    /// 1. Matching [`BufferIdentity::id`] → same buffer (always).
356    /// 2. Matching backing type + matching dmabuf fd number (Linux, DMA
357    ///    tensors only) → same buffer, even across separate `from_fd`
358    ///    imports in the same process.
359    ///
360    /// Two distinct `dup`'d fds pointing at the same kernel dma-buf are
361    /// **not** detected — there is no cheap way to resolve that without a
362    /// round-trip through the kernel.
363    pub fn aliases(&self, other: &Self) -> bool {
364        if self.buffer_identity().id() == other.buffer_identity().id() {
365            return true;
366        }
367        if self.memory() != other.memory() {
368            return false;
369        }
370        #[cfg(target_os = "linux")]
371        if self.memory() == TensorMemory::Dma {
372            use std::os::fd::AsRawFd;
373            if let (Ok(a), Ok(b)) = (self.dmabuf(), other.dmabuf()) {
374                return a.as_raw_fd() == b.as_raw_fd();
375            }
376        }
377        false
378    }
379
380    // --- Downcasting ---
381
382    downcast_methods!(U8, u8, as_u8, as_u8_mut, into_u8);
383    downcast_methods!(I8, i8, as_i8, as_i8_mut, into_i8);
384    downcast_methods!(U16, u16, as_u16, as_u16_mut, into_u16);
385    downcast_methods!(I16, i16, as_i16, as_i16_mut, into_i16);
386    downcast_methods!(U32, u32, as_u32, as_u32_mut, into_u32);
387    downcast_methods!(I32, i32, as_i32, as_i32_mut, into_i32);
388    downcast_methods!(U64, u64, as_u64, as_u64_mut, into_u64);
389    downcast_methods!(I64, i64, as_i64, as_i64_mut, into_i64);
390    downcast_methods!(F16, f16, as_f16, as_f16_mut, into_f16);
391    downcast_methods!(F32, f32, as_f32, as_f32_mut, into_f32);
392    downcast_methods!(F64, f64, as_f64, as_f64_mut, into_f64);
393
394    /// Create a type-erased tensor with the given shape and element type.
395    pub fn new(
396        shape: &[usize],
397        dtype: DType,
398        memory: Option<TensorMemory>,
399        name: Option<&str>,
400    ) -> crate::Result<Self> {
401        match dtype {
402            DType::U8 => Tensor::<u8>::new(shape, memory, name).map(Self::U8),
403            DType::I8 => Tensor::<i8>::new(shape, memory, name).map(Self::I8),
404            DType::U16 => Tensor::<u16>::new(shape, memory, name).map(Self::U16),
405            DType::I16 => Tensor::<i16>::new(shape, memory, name).map(Self::I16),
406            DType::U32 => Tensor::<u32>::new(shape, memory, name).map(Self::U32),
407            DType::I32 => Tensor::<i32>::new(shape, memory, name).map(Self::I32),
408            DType::U64 => Tensor::<u64>::new(shape, memory, name).map(Self::U64),
409            DType::I64 => Tensor::<i64>::new(shape, memory, name).map(Self::I64),
410            DType::F16 => Tensor::<f16>::new(shape, memory, name).map(Self::F16),
411            DType::F32 => Tensor::<f32>::new(shape, memory, name).map(Self::F32),
412            DType::F64 => Tensor::<f64>::new(shape, memory, name).map(Self::F64),
413        }
414    }
415
416    /// Create a type-erased tensor from a file descriptor.
417    #[cfg(unix)]
418    pub fn from_fd(
419        fd: std::os::fd::OwnedFd,
420        shape: &[usize],
421        dtype: DType,
422        name: Option<&str>,
423    ) -> crate::Result<Self> {
424        match dtype {
425            DType::U8 => Tensor::<u8>::from_fd(fd, shape, name).map(Self::U8),
426            DType::I8 => Tensor::<i8>::from_fd(fd, shape, name).map(Self::I8),
427            DType::U16 => Tensor::<u16>::from_fd(fd, shape, name).map(Self::U16),
428            DType::I16 => Tensor::<i16>::from_fd(fd, shape, name).map(Self::I16),
429            DType::U32 => Tensor::<u32>::from_fd(fd, shape, name).map(Self::U32),
430            DType::I32 => Tensor::<i32>::from_fd(fd, shape, name).map(Self::I32),
431            DType::U64 => Tensor::<u64>::from_fd(fd, shape, name).map(Self::U64),
432            DType::I64 => Tensor::<i64>::from_fd(fd, shape, name).map(Self::I64),
433            DType::F16 => Tensor::<f16>::from_fd(fd, shape, name).map(Self::F16),
434            DType::F32 => Tensor::<f32>::from_fd(fd, shape, name).map(Self::F32),
435            DType::F64 => Tensor::<f64>::from_fd(fd, shape, name).map(Self::F64),
436        }
437    }
438
439    /// Wrap an externally-allocated IOSurface as a type-erased tensor
440    /// (macOS only).
441    ///
442    /// # Safety
443    ///
444    /// `surface_ref` must be a valid live `IOSurfaceRef`. `shape` must
445    /// match the IOSurface's pixel dimensions and chosen element type.
446    #[cfg(target_os = "macos")]
447    pub unsafe fn from_iosurface(
448        surface_ref: *mut std::ffi::c_void,
449        shape: &[usize],
450        dtype: DType,
451        name: Option<&str>,
452    ) -> crate::Result<Self> {
453        unsafe {
454            match dtype {
455                DType::U8 => Tensor::<u8>::from_iosurface(surface_ref, shape, name).map(Self::U8),
456                DType::I8 => Tensor::<i8>::from_iosurface(surface_ref, shape, name).map(Self::I8),
457                DType::U16 => {
458                    Tensor::<u16>::from_iosurface(surface_ref, shape, name).map(Self::U16)
459                }
460                DType::I16 => {
461                    Tensor::<i16>::from_iosurface(surface_ref, shape, name).map(Self::I16)
462                }
463                DType::U32 => {
464                    Tensor::<u32>::from_iosurface(surface_ref, shape, name).map(Self::U32)
465                }
466                DType::I32 => {
467                    Tensor::<i32>::from_iosurface(surface_ref, shape, name).map(Self::I32)
468                }
469                DType::U64 => {
470                    Tensor::<u64>::from_iosurface(surface_ref, shape, name).map(Self::U64)
471                }
472                DType::I64 => {
473                    Tensor::<i64>::from_iosurface(surface_ref, shape, name).map(Self::I64)
474                }
475                DType::F16 => {
476                    Tensor::<f16>::from_iosurface(surface_ref, shape, name).map(Self::F16)
477                }
478                DType::F32 => {
479                    Tensor::<f32>::from_iosurface(surface_ref, shape, name).map(Self::F32)
480                }
481                DType::F64 => {
482                    Tensor::<f64>::from_iosurface(surface_ref, shape, name).map(Self::F64)
483                }
484            }
485        }
486    }
487
488    /// IOSurfaceID for cross-process surface sharing (macOS only).
489    /// Returns `None` when the tensor is not IOSurface-backed.
490    #[cfg(target_os = "macos")]
491    pub fn iosurface_id(&self) -> Option<u32> {
492        dispatch!(self, iosurface_id)
493    }
494
495    /// Borrow the raw `IOSurfaceRef` backing this tensor (macOS only).
496    /// Returns `None` when the tensor is not IOSurface-backed. The
497    /// pointer's lifetime is tied to `self`.
498    #[cfg(target_os = "macos")]
499    pub fn iosurface_ref(&self) -> Option<*mut std::ffi::c_void> {
500        dispatch!(self, iosurface_ref)
501    }
502
503    /// Create a type-erased image tensor.
504    ///
505    /// # Arguments
506    ///
507    /// * `width` - Image width in pixels
508    /// * `height` - Image height in pixels
509    /// * `format` - Pixel format
510    /// * `dtype` - Element type discriminant
511    /// * `memory` - Optional memory backend (None selects the best available)
512    ///
513    /// # Returns
514    ///
515    /// A new `TensorDyn` wrapping an image tensor of the requested element type.
516    ///
517    /// # Errors
518    ///
519    /// Returns an error if the underlying `Tensor::image` call fails.
520    pub fn image(
521        width: usize,
522        height: usize,
523        format: PixelFormat,
524        dtype: DType,
525        memory: Option<TensorMemory>,
526    ) -> crate::Result<Self> {
527        match dtype {
528            DType::U8 => Tensor::<u8>::image(width, height, format, memory).map(Self::U8),
529            DType::I8 => Tensor::<i8>::image(width, height, format, memory).map(Self::I8),
530            DType::U16 => Tensor::<u16>::image(width, height, format, memory).map(Self::U16),
531            DType::I16 => Tensor::<i16>::image(width, height, format, memory).map(Self::I16),
532            DType::U32 => Tensor::<u32>::image(width, height, format, memory).map(Self::U32),
533            DType::I32 => Tensor::<i32>::image(width, height, format, memory).map(Self::I32),
534            DType::U64 => Tensor::<u64>::image(width, height, format, memory).map(Self::U64),
535            DType::I64 => Tensor::<i64>::image(width, height, format, memory).map(Self::I64),
536            DType::F16 => Tensor::<f16>::image(width, height, format, memory).map(Self::F16),
537            DType::F32 => Tensor::<f32>::image(width, height, format, memory).map(Self::F32),
538            DType::F64 => Tensor::<f64>::image(width, height, format, memory).map(Self::F64),
539        }
540    }
541
542    /// Create a DMA-backed image tensor with an explicit row stride that
543    /// may exceed the natural `width * channels * sizeof(T)` pitch.
544    ///
545    /// See [`Tensor::image_with_stride`] for the detailed contract and
546    /// constraints. The TensorDyn wrapper dispatches to the appropriate
547    /// monomorphised `Tensor<T>` based on `dtype`.
548    ///
549    /// # Example
550    ///
551    /// ```no_run
552    /// use edgefirst_tensor::{TensorDyn, PixelFormat, DType, TensorMemory};
553    /// # fn main() -> edgefirst_tensor::Result<()> {
554    /// // Allocate a 3004×1688 RGBA8 canvas with 64-byte pitch alignment
555    /// // (12032 bytes per row instead of the natural 12016).
556    /// let img = TensorDyn::image_with_stride(
557    ///     3004, 1688,
558    ///     PixelFormat::Rgba, DType::U8,
559    ///     12032,
560    ///     Some(TensorMemory::Dma),
561    /// )?;
562    /// assert_eq!(img.width(), Some(3004));       // logical, unchanged
563    /// assert_eq!(img.effective_row_stride(), Some(12032)); // padded
564    /// # Ok(())
565    /// # }
566    /// ```
567    pub fn image_with_stride(
568        width: usize,
569        height: usize,
570        format: PixelFormat,
571        dtype: DType,
572        row_stride_bytes: usize,
573        memory: Option<TensorMemory>,
574    ) -> crate::Result<Self> {
575        match dtype {
576            DType::U8 => {
577                Tensor::<u8>::image_with_stride(width, height, format, row_stride_bytes, memory)
578                    .map(Self::U8)
579            }
580            DType::I8 => {
581                Tensor::<i8>::image_with_stride(width, height, format, row_stride_bytes, memory)
582                    .map(Self::I8)
583            }
584            DType::U16 => {
585                Tensor::<u16>::image_with_stride(width, height, format, row_stride_bytes, memory)
586                    .map(Self::U16)
587            }
588            DType::I16 => {
589                Tensor::<i16>::image_with_stride(width, height, format, row_stride_bytes, memory)
590                    .map(Self::I16)
591            }
592            DType::U32 => {
593                Tensor::<u32>::image_with_stride(width, height, format, row_stride_bytes, memory)
594                    .map(Self::U32)
595            }
596            DType::I32 => {
597                Tensor::<i32>::image_with_stride(width, height, format, row_stride_bytes, memory)
598                    .map(Self::I32)
599            }
600            DType::U64 => {
601                Tensor::<u64>::image_with_stride(width, height, format, row_stride_bytes, memory)
602                    .map(Self::U64)
603            }
604            DType::I64 => {
605                Tensor::<i64>::image_with_stride(width, height, format, row_stride_bytes, memory)
606                    .map(Self::I64)
607            }
608            DType::F16 => {
609                Tensor::<f16>::image_with_stride(width, height, format, row_stride_bytes, memory)
610                    .map(Self::F16)
611            }
612            DType::F32 => {
613                Tensor::<f32>::image_with_stride(width, height, format, row_stride_bytes, memory)
614                    .map(Self::F32)
615            }
616            DType::F64 => {
617                Tensor::<f64>::image_with_stride(width, height, format, row_stride_bytes, memory)
618                    .map(Self::F64)
619            }
620        }
621    }
622}
623
624// --- From impls ---
625
626impl From<Tensor<u8>> for TensorDyn {
627    fn from(t: Tensor<u8>) -> Self {
628        Self::U8(t)
629    }
630}
631
632impl From<Tensor<i8>> for TensorDyn {
633    fn from(t: Tensor<i8>) -> Self {
634        Self::I8(t)
635    }
636}
637
638impl From<Tensor<u16>> for TensorDyn {
639    fn from(t: Tensor<u16>) -> Self {
640        Self::U16(t)
641    }
642}
643
644impl From<Tensor<i16>> for TensorDyn {
645    fn from(t: Tensor<i16>) -> Self {
646        Self::I16(t)
647    }
648}
649
650impl From<Tensor<u32>> for TensorDyn {
651    fn from(t: Tensor<u32>) -> Self {
652        Self::U32(t)
653    }
654}
655
656impl From<Tensor<i32>> for TensorDyn {
657    fn from(t: Tensor<i32>) -> Self {
658        Self::I32(t)
659    }
660}
661
662impl From<Tensor<u64>> for TensorDyn {
663    fn from(t: Tensor<u64>) -> Self {
664        Self::U64(t)
665    }
666}
667
668impl From<Tensor<i64>> for TensorDyn {
669    fn from(t: Tensor<i64>) -> Self {
670        Self::I64(t)
671    }
672}
673
674impl From<Tensor<f16>> for TensorDyn {
675    fn from(t: Tensor<f16>) -> Self {
676        Self::F16(t)
677    }
678}
679
680impl From<Tensor<f32>> for TensorDyn {
681    fn from(t: Tensor<f32>) -> Self {
682        Self::F32(t)
683    }
684}
685
686impl From<Tensor<f64>> for TensorDyn {
687    fn from(t: Tensor<f64>) -> Self {
688        Self::F64(t)
689    }
690}
691
/// Formats a `TensorDyn` by forwarding to the wrapped tensor through
/// `dispatch!`, so the output is exactly what the inner `Tensor<T>` writes.
/// No variant name or dtype is added around it.
impl fmt::Debug for TensorDyn {
    // NOTE(review): `t.fmt(f)` relies on method resolution for the inner
    // tensor type; an explicit `fmt::Debug::fmt(t, f)` would remove any
    // ambiguity — confirm which `fmt` is selected before changing it.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        dispatch!(self, fmt, f)
    }
}
697
698#[cfg(test)]
699mod tests {
700    use super::*;
701
702    #[test]
703    fn from_typed_tensor() {
704        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
705        let dyn_t: TensorDyn = t.into();
706        assert_eq!(dyn_t.dtype(), DType::U8);
707        assert_eq!(dyn_t.shape(), &[10]);
708    }
709
710    #[test]
711    fn downcast_ref() {
712        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
713        let dyn_t: TensorDyn = t.into();
714        assert!(dyn_t.as_u8().is_some());
715        assert!(dyn_t.as_i8().is_none());
716    }
717
718    #[test]
719    fn downcast_into() {
720        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
721        let dyn_t: TensorDyn = t.into();
722        let back = dyn_t.into_u8().unwrap();
723        assert_eq!(back.shape(), &[10]);
724    }
725
726    #[test]
727    fn image_accessors() {
728        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
729        let dyn_t: TensorDyn = t.into();
730        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgba));
731        assert_eq!(dyn_t.width(), Some(640));
732        assert_eq!(dyn_t.height(), Some(480));
733        assert!(!dyn_t.is_multiplane());
734    }
735
736    #[test]
737    fn image_constructor() {
738        let dyn_t = TensorDyn::image(640, 480, PixelFormat::Rgb, DType::U8, None).unwrap();
739        assert_eq!(dyn_t.dtype(), DType::U8);
740        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgb));
741        assert_eq!(dyn_t.width(), Some(640));
742    }
743
744    #[test]
745    fn image_constructor_i8() {
746        let dyn_t = TensorDyn::image(640, 480, PixelFormat::Rgb, DType::I8, None).unwrap();
747        assert_eq!(dyn_t.dtype(), DType::I8);
748        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgb));
749    }
750
751    #[test]
752    fn set_format_packed() {
753        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
754        assert_eq!(t.format(), None);
755        t.set_format(PixelFormat::Rgb).unwrap();
756        assert_eq!(t.format(), Some(PixelFormat::Rgb));
757        assert_eq!(t.width(), Some(640));
758        assert_eq!(t.height(), Some(480));
759    }
760
761    #[test]
762    fn set_format_planar() {
763        let mut t = TensorDyn::new(&[3, 480, 640], DType::U8, None, None).unwrap();
764        t.set_format(PixelFormat::PlanarRgb).unwrap();
765        assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
766        assert_eq!(t.width(), Some(640));
767        assert_eq!(t.height(), Some(480));
768    }
769
770    #[test]
771    fn set_format_rejects_wrong_shape() {
772        let mut t = TensorDyn::new(&[480, 640, 4], DType::U8, None, None).unwrap();
773        assert!(t.set_format(PixelFormat::Rgb).is_err());
774    }
775
776    #[test]
777    fn with_format_builder() {
778        let t = TensorDyn::new(&[480, 640, 4], DType::U8, None, None)
779            .unwrap()
780            .with_format(PixelFormat::Rgba)
781            .unwrap();
782        assert_eq!(t.format(), Some(PixelFormat::Rgba));
783        assert_eq!(t.width(), Some(640));
784        assert_eq!(t.height(), Some(480));
785    }
786
787    #[cfg(target_os = "linux")]
788    #[test]
789    fn dmabuf_clone_mem_tensor_fails() {
790        let t = TensorDyn::new(&[480, 640, 3], DType::U8, Some(TensorMemory::Mem), None).unwrap();
791        assert_eq!(t.memory(), TensorMemory::Mem);
792        assert!(t.dmabuf_clone().is_err());
793    }
794
795    #[cfg(target_os = "linux")]
796    #[test]
797    fn dmabuf_mem_tensor_fails() {
798        let t = TensorDyn::new(&[480, 640, 3], DType::U8, Some(TensorMemory::Mem), None).unwrap();
799        assert!(t.dmabuf().is_err());
800    }
801
802    #[test]
803    fn set_format_semi_planar_nv12() {
804        // 720 rows = 480 * 3/2 (NV12: height + height/2 for chroma)
805        let mut t = TensorDyn::new(&[720, 640], DType::U8, Some(TensorMemory::Mem), None).unwrap();
806        t.set_format(PixelFormat::Nv12).unwrap();
807        assert_eq!(t.format(), Some(PixelFormat::Nv12));
808        assert_eq!(t.width(), Some(640));
809        assert_eq!(t.height(), Some(480));
810    }
811
812    #[test]
813    fn set_format_semi_planar_nv16() {
814        // 960 rows = 480 * 2 (NV16: height + height for chroma)
815        let mut t = TensorDyn::new(&[960, 640], DType::U8, Some(TensorMemory::Mem), None).unwrap();
816        t.set_format(PixelFormat::Nv16).unwrap();
817        assert_eq!(t.format(), Some(PixelFormat::Nv16));
818        assert_eq!(t.width(), Some(640));
819        assert_eq!(t.height(), Some(480));
820    }
821
822    #[test]
823    fn with_format_rejects_wrong_shape() {
824        let result = TensorDyn::new(&[480, 640, 4], DType::U8, None, None)
825            .unwrap()
826            .with_format(PixelFormat::Rgb);
827        assert!(result.is_err());
828    }
829
830    #[test]
831    fn set_format_preserved_after_rejection() {
832        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
833        t.set_format(PixelFormat::Rgb).unwrap();
834        assert_eq!(t.format(), Some(PixelFormat::Rgb));
835
836        // Rgba requires 4 channels, should fail on a 3-channel tensor
837        assert!(t.set_format(PixelFormat::Rgba).is_err());
838
839        // Original format should be preserved
840        assert_eq!(t.format(), Some(PixelFormat::Rgb));
841    }
842
843    #[test]
844    fn set_format_idempotent() {
845        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
846        t.set_format(PixelFormat::Rgb).unwrap();
847        t.set_format(PixelFormat::Rgb).unwrap();
848        assert_eq!(t.format(), Some(PixelFormat::Rgb));
849        assert_eq!(t.width(), Some(640));
850        assert_eq!(t.height(), Some(480));
851    }
852
853    // --- Row stride tests ---
854
855    #[test]
856    fn set_row_stride_valid() {
857        // RGBA 100px wide: min stride = 400, set 512
858        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None).unwrap();
859        t.set_row_stride(512).unwrap();
860        assert_eq!(t.row_stride(), Some(512));
861        assert_eq!(t.effective_row_stride(), Some(512));
862    }
863
864    #[test]
865    fn set_row_stride_equals_min() {
866        // RGB 100px: min stride = 300, set exactly 300
867        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgb, DType::U8, None).unwrap();
868        t.set_row_stride(300).unwrap();
869        assert_eq!(t.row_stride(), Some(300));
870    }
871
872    #[test]
873    fn set_row_stride_too_small() {
874        // RGBA 100px: min stride = 400, set 300
875        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None).unwrap();
876        assert!(t.set_row_stride(300).is_err());
877        assert_eq!(t.row_stride(), None);
878    }
879
880    #[test]
881    fn set_row_stride_zero() {
882        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgb, DType::U8, None).unwrap();
883        assert!(t.set_row_stride(0).is_err());
884    }
885
886    #[test]
887    fn set_row_stride_requires_format() {
888        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
889        assert!(t.set_row_stride(2048).is_err());
890    }
891
892    #[test]
893    fn effective_row_stride_without_stride() {
894        let t = TensorDyn::image(100, 100, PixelFormat::Rgb, DType::U8, None).unwrap();
895        assert_eq!(t.row_stride(), None);
896        assert_eq!(t.effective_row_stride(), Some(300)); // 100 * 3
897    }
898
899    #[test]
900    fn effective_row_stride_no_format() {
901        let t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
902        assert_eq!(t.effective_row_stride(), None);
903    }
904
905    #[test]
906    fn with_row_stride_builder() {
907        let t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None)
908            .unwrap()
909            .with_row_stride(512)
910            .unwrap();
911        assert_eq!(t.row_stride(), Some(512));
912        assert_eq!(t.effective_row_stride(), Some(512));
913    }
914
915    #[test]
916    fn with_row_stride_rejects_small() {
917        let result = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None)
918            .unwrap()
919            .with_row_stride(200);
920        assert!(result.is_err());
921    }
922
923    #[test]
924    fn set_format_clears_row_stride() {
925        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
926        t.set_format(PixelFormat::Rgb).unwrap();
927        t.set_row_stride(2048).unwrap();
928        assert_eq!(t.row_stride(), Some(2048));
929
930        // Incompatible format change (4-chan on 3-chan shape) fails — stride preserved
931        let _ = t.set_format(PixelFormat::Bgra);
932        assert_eq!(t.row_stride(), Some(2048));
933
934        // Re-set to same format — stride preserved
935        t.set_format(PixelFormat::Rgb).unwrap();
936        assert_eq!(t.row_stride(), Some(2048));
937
938        // Reshape clears format and stride
939        t.reshape(&[480 * 640 * 3]).unwrap();
940        assert_eq!(t.row_stride(), None);
941        assert_eq!(t.format(), None);
942    }
943
944    #[test]
945    fn set_format_different_compatible_clears_stride() {
946        // RGBA and BGRA are both 4-channel packed — switching between them
947        // succeeds and must clear the stored stride.
948        let mut t = TensorDyn::new(&[480, 640, 4], DType::U8, None, None).unwrap();
949        t.set_format(PixelFormat::Rgba).unwrap();
950        t.set_row_stride(4096).unwrap();
951        assert_eq!(t.row_stride(), Some(4096));
952
953        // Successful format change to a different compatible format clears stride
954        t.set_format(PixelFormat::Bgra).unwrap();
955        assert_eq!(t.format(), Some(PixelFormat::Bgra));
956        assert_eq!(t.row_stride(), None);
957    }
958
959    #[test]
960    fn set_format_same_preserves_stride() {
961        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgb, DType::U8, None).unwrap();
962        t.set_row_stride(512).unwrap();
963        // Re-setting the same format should not clear stride
964        t.set_format(PixelFormat::Rgb).unwrap();
965        assert_eq!(t.row_stride(), Some(512));
966    }
967
968    #[test]
969    fn effective_row_stride_planar() {
970        let t = TensorDyn::image(640, 480, PixelFormat::PlanarRgb, DType::U8, None).unwrap();
971        assert_eq!(t.effective_row_stride(), Some(640)); // planar: width only
972    }
973
974    #[test]
975    fn effective_row_stride_nv12() {
976        let t = TensorDyn::image(640, 480, PixelFormat::Nv12, DType::U8, None).unwrap();
977        assert_eq!(t.effective_row_stride(), Some(640)); // semi-planar: width only
978    }
979
980    #[test]
981    fn map_rejects_strided_tensor() {
982        let mut t =
983            Tensor::<u8>::image(100, 100, PixelFormat::Rgba, Some(TensorMemory::Mem)).unwrap();
984        // Map works before stride is set
985        assert!(t.map().is_ok());
986        // After setting stride, map should be rejected
987        t.set_row_stride(512).unwrap();
988        let err = t.map();
989        assert!(err.is_err());
990    }
991
992    // ── plane_offset tests ──────────────────────────────────────────
993
994    #[test]
995    fn plane_offset_default_none() {
996        let t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None).unwrap();
997        assert_eq!(t.plane_offset(), None);
998    }
999
1000    #[test]
1001    fn set_plane_offset_basic() {
1002        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None).unwrap();
1003        t.set_plane_offset(4096);
1004        assert_eq!(t.plane_offset(), Some(4096));
1005    }
1006
1007    #[test]
1008    fn set_plane_offset_zero() {
1009        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgb, DType::U8, None).unwrap();
1010        t.set_plane_offset(0);
1011        assert_eq!(t.plane_offset(), Some(0));
1012    }
1013
1014    #[test]
1015    fn set_plane_offset_no_format() {
1016        // plane_offset does not require format (it is format-independent)
1017        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
1018        t.set_plane_offset(4096);
1019        assert_eq!(t.plane_offset(), Some(4096));
1020    }
1021
1022    #[test]
1023    fn with_plane_offset_builder() {
1024        let t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None)
1025            .unwrap()
1026            .with_plane_offset(8192);
1027        assert_eq!(t.plane_offset(), Some(8192));
1028    }
1029
1030    #[test]
1031    fn set_format_clears_plane_offset() {
1032        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
1033        t.set_format(PixelFormat::Rgb).unwrap();
1034        t.set_plane_offset(4096);
1035        assert_eq!(t.plane_offset(), Some(4096));
1036
1037        // Re-set same format — offset preserved
1038        t.set_format(PixelFormat::Rgb).unwrap();
1039        assert_eq!(t.plane_offset(), Some(4096));
1040
1041        // Reshape clears everything
1042        t.reshape(&[480 * 640 * 3]).unwrap();
1043        assert_eq!(t.plane_offset(), None);
1044        assert_eq!(t.format(), None);
1045    }
1046
1047    #[test]
1048    fn map_rejects_offset_tensor() {
1049        let mut t =
1050            Tensor::<u8>::image(100, 100, PixelFormat::Rgba, Some(TensorMemory::Mem)).unwrap();
1051        // Map works before offset is set
1052        assert!(t.map().is_ok());
1053        // After setting non-zero offset, map should be rejected
1054        t.set_plane_offset(4096);
1055        assert!(t.map().is_err());
1056    }
1057
1058    #[test]
1059    fn map_accepts_zero_offset_tensor() {
1060        let mut t =
1061            Tensor::<u8>::image(100, 100, PixelFormat::Rgba, Some(TensorMemory::Mem)).unwrap();
1062        t.set_plane_offset(0);
1063        // Zero offset is fine for CPU mapping
1064        assert!(t.map().is_ok());
1065    }
1066
1067    #[test]
1068    fn from_planes_propagates_plane_offset() {
1069        let mut luma =
1070            Tensor::<u8>::new(&[480, 640], Some(TensorMemory::Mem), Some("luma")).unwrap();
1071        luma.set_plane_offset(4096);
1072        let chroma =
1073            Tensor::<u8>::new(&[240, 640], Some(TensorMemory::Mem), Some("chroma")).unwrap();
1074        let combined = Tensor::<u8>::from_planes(luma, chroma, PixelFormat::Nv12).unwrap();
1075        assert_eq!(combined.plane_offset(), Some(4096));
1076    }
1077}