Skip to main content

edgefirst_tensor/
tensor_dyn.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{DType, PixelFormat, Tensor, TensorMemory, TensorTrait};
5use half::f16;
6use std::fmt;
7
8/// Type-erased tensor. Wraps a `Tensor<T>` with runtime element type.
9#[non_exhaustive]
10pub enum TensorDyn {
11    /// Unsigned 8-bit integer tensor.
12    U8(Tensor<u8>),
13    /// Signed 8-bit integer tensor.
14    I8(Tensor<i8>),
15    /// Unsigned 16-bit integer tensor.
16    U16(Tensor<u16>),
17    /// Signed 16-bit integer tensor.
18    I16(Tensor<i16>),
19    /// Unsigned 32-bit integer tensor.
20    U32(Tensor<u32>),
21    /// Signed 32-bit integer tensor.
22    I32(Tensor<i32>),
23    /// Unsigned 64-bit integer tensor.
24    U64(Tensor<u64>),
25    /// Signed 64-bit integer tensor.
26    I64(Tensor<i64>),
27    /// 16-bit floating-point tensor.
28    F16(Tensor<f16>),
29    /// 32-bit floating-point tensor.
30    F32(Tensor<f32>),
31    /// 64-bit floating-point tensor.
32    F64(Tensor<f64>),
33}
34
/// Forward a method call to the concrete `Tensor<T>` held inside a `TensorDyn`.
///
/// `dispatch!(target, method, args...)` expands to a `match` over every
/// variant and calls `inner.method(args...)` on the wrapped tensor. The method
/// must exist on every `Tensor<T>` instantiation, and every arm must produce
/// the same type so that the `match` type-checks.
macro_rules! dispatch {
    ($this:expr, $method:ident $(, $arg:expr)*) => {
        match $this {
            TensorDyn::U8(inner) => inner.$method($($arg),*),
            TensorDyn::I8(inner) => inner.$method($($arg),*),
            TensorDyn::U16(inner) => inner.$method($($arg),*),
            TensorDyn::I16(inner) => inner.$method($($arg),*),
            TensorDyn::U32(inner) => inner.$method($($arg),*),
            TensorDyn::I32(inner) => inner.$method($($arg),*),
            TensorDyn::U64(inner) => inner.$method($($arg),*),
            TensorDyn::I64(inner) => inner.$method($($arg),*),
            TensorDyn::F16(inner) => inner.$method($($arg),*),
            TensorDyn::F32(inner) => inner.$method($($arg),*),
            TensorDyn::F64(inner) => inner.$method($($arg),*),
        }
    };
}
53
/// Emit the three typed accessors for one `TensorDyn` variant.
///
/// The accessors are a shared borrow (`as_*`), a mutable borrow (`as_*_mut`)
/// and a consuming unwrap (`into_*`). The unwrap hands the untouched value
/// back as `Err` when the variant does not match.
macro_rules! downcast_methods {
    ($variant:ident, $elem:ty, $ref_fn:ident, $mut_fn:ident, $owned_fn:ident) => {
        #[doc = concat!(
            "Borrow the inner `Tensor<", stringify!($elem),
            ">`, or `None` if a different element type is held."
        )]
        pub fn $ref_fn(&self) -> Option<&Tensor<$elem>> {
            if let Self::$variant(inner) = self {
                Some(inner)
            } else {
                None
            }
        }

        #[doc = concat!(
            "Mutably borrow the inner `Tensor<", stringify!($elem),
            ">`, or `None` if a different element type is held."
        )]
        pub fn $mut_fn(&mut self) -> Option<&mut Tensor<$elem>> {
            if let Self::$variant(inner) = self {
                Some(inner)
            } else {
                None
            }
        }

        #[doc = concat!(
            "Take the inner `Tensor<", stringify!($elem),
            ">` by value; on a type mismatch the original `TensorDyn` is returned as `Err`."
        )]
        pub fn $owned_fn(self) -> Result<Tensor<$elem>, Self> {
            match self {
                Self::$variant(inner) => Ok(inner),
                mismatched => Err(mismatched),
            }
        }
    };
}
82
83impl TensorDyn {
84    /// Return the runtime element type discriminant.
85    pub fn dtype(&self) -> DType {
86        match self {
87            Self::U8(_) => DType::U8,
88            Self::I8(_) => DType::I8,
89            Self::U16(_) => DType::U16,
90            Self::I16(_) => DType::I16,
91            Self::U32(_) => DType::U32,
92            Self::I32(_) => DType::I32,
93            Self::U64(_) => DType::U64,
94            Self::I64(_) => DType::I64,
95            Self::F16(_) => DType::F16,
96            Self::F32(_) => DType::F32,
97            Self::F64(_) => DType::F64,
98        }
99    }
100
101    /// Return the tensor shape.
102    pub fn shape(&self) -> &[usize] {
103        dispatch!(self, shape)
104    }
105
106    /// Return the tensor name.
107    pub fn name(&self) -> String {
108        dispatch!(self, name)
109    }
110
111    /// Return the pixel format (None if not an image tensor).
112    pub fn format(&self) -> Option<PixelFormat> {
113        dispatch!(self, format)
114    }
115
116    /// Return the image width (None if not an image tensor).
117    pub fn width(&self) -> Option<usize> {
118        dispatch!(self, width)
119    }
120
121    /// Return the image height (None if not an image tensor).
122    pub fn height(&self) -> Option<usize> {
123        dispatch!(self, height)
124    }
125
126    /// Return the total size of this tensor in bytes.
127    pub fn size(&self) -> usize {
128        dispatch!(self, size)
129    }
130
131    /// Return the memory allocation type.
132    pub fn memory(&self) -> TensorMemory {
133        dispatch!(self, memory)
134    }
135
136    /// Reshape this tensor. Total element count must remain the same.
137    pub fn reshape(&mut self, shape: &[usize]) -> crate::Result<()> {
138        dispatch!(self, reshape, shape)
139    }
140
141    /// Attach pixel format metadata to this tensor.
142    ///
143    /// Validates that the tensor's shape is compatible with the format's
144    /// layout (packed, planar, or semi-planar).
145    ///
146    /// # Arguments
147    ///
148    /// * `format` - The pixel format to attach
149    ///
150    /// # Returns
151    ///
152    /// `Ok(())` on success, with the format stored as metadata on the tensor.
153    ///
154    /// # Errors
155    ///
156    /// Returns `Error::InvalidShape` if the tensor shape doesn't match
157    /// the expected layout for the given format.
158    pub fn set_format(&mut self, format: PixelFormat) -> crate::Result<()> {
159        dispatch!(self, set_format, format)
160    }
161
162    /// Attach pixel format metadata, consuming and returning self.
163    ///
164    /// Enables builder-style chaining.
165    ///
166    /// # Arguments
167    ///
168    /// * `format` - The pixel format to attach
169    ///
170    /// # Returns
171    ///
172    /// The tensor with format metadata attached.
173    ///
174    /// # Errors
175    ///
176    /// Returns `Error::InvalidShape` if the tensor shape doesn't match
177    /// the expected layout for the given format.
178    pub fn with_format(mut self, format: PixelFormat) -> crate::Result<Self> {
179        self.set_format(format)?;
180        Ok(self)
181    }
182
183    /// Clone the file descriptor associated with this tensor.
184    #[cfg(unix)]
185    pub fn clone_fd(&self) -> crate::Result<std::os::fd::OwnedFd> {
186        dispatch!(self, clone_fd)
187    }
188
189    /// Clone the DMA-BUF file descriptor backing this tensor (Linux only).
190    ///
191    /// # Returns
192    ///
193    /// An owned duplicate of the DMA-BUF file descriptor.
194    ///
195    /// # Errors
196    ///
197    /// * `Error::NotImplemented` if the tensor is not DMA-backed (Mem/Shm/Pbo)
198    /// * `Error::IoError` if the fd clone syscall fails (e.g., fd limit reached)
199    #[cfg(target_os = "linux")]
200    pub fn dmabuf_clone(&self) -> crate::Result<std::os::fd::OwnedFd> {
201        if self.memory() != TensorMemory::Dma {
202            return Err(crate::Error::NotImplemented(format!(
203                "dmabuf_clone requires DMA-backed tensor, got {:?}",
204                self.memory()
205            )));
206        }
207        self.clone_fd()
208    }
209
210    /// Borrow the DMA-BUF file descriptor backing this tensor (Linux only).
211    ///
212    /// # Returns
213    ///
214    /// A borrowed reference to the DMA-BUF file descriptor, tied to `self`'s
215    /// lifetime.
216    ///
217    /// # Errors
218    ///
219    /// * `Error::NotImplemented` if the tensor is not DMA-backed
220    #[cfg(target_os = "linux")]
221    pub fn dmabuf(&self) -> crate::Result<std::os::fd::BorrowedFd<'_>> {
222        dispatch!(self, dmabuf)
223    }
224
225    /// Return `true` if this tensor uses separate plane allocations.
226    pub fn is_multiplane(&self) -> bool {
227        dispatch!(self, is_multiplane)
228    }
229
230    // --- Downcasting ---
231
232    downcast_methods!(U8, u8, as_u8, as_u8_mut, into_u8);
233    downcast_methods!(I8, i8, as_i8, as_i8_mut, into_i8);
234    downcast_methods!(U16, u16, as_u16, as_u16_mut, into_u16);
235    downcast_methods!(I16, i16, as_i16, as_i16_mut, into_i16);
236    downcast_methods!(U32, u32, as_u32, as_u32_mut, into_u32);
237    downcast_methods!(I32, i32, as_i32, as_i32_mut, into_i32);
238    downcast_methods!(U64, u64, as_u64, as_u64_mut, into_u64);
239    downcast_methods!(I64, i64, as_i64, as_i64_mut, into_i64);
240    downcast_methods!(F16, f16, as_f16, as_f16_mut, into_f16);
241    downcast_methods!(F32, f32, as_f32, as_f32_mut, into_f32);
242    downcast_methods!(F64, f64, as_f64, as_f64_mut, into_f64);
243
244    /// Create a type-erased tensor with the given shape and element type.
245    pub fn new(
246        shape: &[usize],
247        dtype: DType,
248        memory: Option<TensorMemory>,
249        name: Option<&str>,
250    ) -> crate::Result<Self> {
251        match dtype {
252            DType::U8 => Tensor::<u8>::new(shape, memory, name).map(Self::U8),
253            DType::I8 => Tensor::<i8>::new(shape, memory, name).map(Self::I8),
254            DType::U16 => Tensor::<u16>::new(shape, memory, name).map(Self::U16),
255            DType::I16 => Tensor::<i16>::new(shape, memory, name).map(Self::I16),
256            DType::U32 => Tensor::<u32>::new(shape, memory, name).map(Self::U32),
257            DType::I32 => Tensor::<i32>::new(shape, memory, name).map(Self::I32),
258            DType::U64 => Tensor::<u64>::new(shape, memory, name).map(Self::U64),
259            DType::I64 => Tensor::<i64>::new(shape, memory, name).map(Self::I64),
260            DType::F16 => Tensor::<f16>::new(shape, memory, name).map(Self::F16),
261            DType::F32 => Tensor::<f32>::new(shape, memory, name).map(Self::F32),
262            DType::F64 => Tensor::<f64>::new(shape, memory, name).map(Self::F64),
263        }
264    }
265
266    /// Create a type-erased tensor from a file descriptor.
267    #[cfg(unix)]
268    pub fn from_fd(
269        fd: std::os::fd::OwnedFd,
270        shape: &[usize],
271        dtype: DType,
272        name: Option<&str>,
273    ) -> crate::Result<Self> {
274        match dtype {
275            DType::U8 => Tensor::<u8>::from_fd(fd, shape, name).map(Self::U8),
276            DType::I8 => Tensor::<i8>::from_fd(fd, shape, name).map(Self::I8),
277            DType::U16 => Tensor::<u16>::from_fd(fd, shape, name).map(Self::U16),
278            DType::I16 => Tensor::<i16>::from_fd(fd, shape, name).map(Self::I16),
279            DType::U32 => Tensor::<u32>::from_fd(fd, shape, name).map(Self::U32),
280            DType::I32 => Tensor::<i32>::from_fd(fd, shape, name).map(Self::I32),
281            DType::U64 => Tensor::<u64>::from_fd(fd, shape, name).map(Self::U64),
282            DType::I64 => Tensor::<i64>::from_fd(fd, shape, name).map(Self::I64),
283            DType::F16 => Tensor::<f16>::from_fd(fd, shape, name).map(Self::F16),
284            DType::F32 => Tensor::<f32>::from_fd(fd, shape, name).map(Self::F32),
285            DType::F64 => Tensor::<f64>::from_fd(fd, shape, name).map(Self::F64),
286        }
287    }
288
289    /// Create a type-erased image tensor.
290    ///
291    /// # Arguments
292    ///
293    /// * `width` - Image width in pixels
294    /// * `height` - Image height in pixels
295    /// * `format` - Pixel format
296    /// * `dtype` - Element type discriminant
297    /// * `memory` - Optional memory backend (None selects the best available)
298    ///
299    /// # Returns
300    ///
301    /// A new `TensorDyn` wrapping an image tensor of the requested element type.
302    ///
303    /// # Errors
304    ///
305    /// Returns an error if the underlying `Tensor::image` call fails.
306    pub fn image(
307        width: usize,
308        height: usize,
309        format: PixelFormat,
310        dtype: DType,
311        memory: Option<TensorMemory>,
312    ) -> crate::Result<Self> {
313        match dtype {
314            DType::U8 => Tensor::<u8>::image(width, height, format, memory).map(Self::U8),
315            DType::I8 => Tensor::<i8>::image(width, height, format, memory).map(Self::I8),
316            DType::U16 => Tensor::<u16>::image(width, height, format, memory).map(Self::U16),
317            DType::I16 => Tensor::<i16>::image(width, height, format, memory).map(Self::I16),
318            DType::U32 => Tensor::<u32>::image(width, height, format, memory).map(Self::U32),
319            DType::I32 => Tensor::<i32>::image(width, height, format, memory).map(Self::I32),
320            DType::U64 => Tensor::<u64>::image(width, height, format, memory).map(Self::U64),
321            DType::I64 => Tensor::<i64>::image(width, height, format, memory).map(Self::I64),
322            DType::F16 => Tensor::<f16>::image(width, height, format, memory).map(Self::F16),
323            DType::F32 => Tensor::<f32>::image(width, height, format, memory).map(Self::F32),
324            DType::F64 => Tensor::<f64>::image(width, height, format, memory).map(Self::F64),
325        }
326    }
327}
328
329// --- From impls ---
330
331impl From<Tensor<u8>> for TensorDyn {
332    fn from(t: Tensor<u8>) -> Self {
333        Self::U8(t)
334    }
335}
336
337impl From<Tensor<i8>> for TensorDyn {
338    fn from(t: Tensor<i8>) -> Self {
339        Self::I8(t)
340    }
341}
342
343impl From<Tensor<u16>> for TensorDyn {
344    fn from(t: Tensor<u16>) -> Self {
345        Self::U16(t)
346    }
347}
348
349impl From<Tensor<i16>> for TensorDyn {
350    fn from(t: Tensor<i16>) -> Self {
351        Self::I16(t)
352    }
353}
354
355impl From<Tensor<u32>> for TensorDyn {
356    fn from(t: Tensor<u32>) -> Self {
357        Self::U32(t)
358    }
359}
360
361impl From<Tensor<i32>> for TensorDyn {
362    fn from(t: Tensor<i32>) -> Self {
363        Self::I32(t)
364    }
365}
366
367impl From<Tensor<u64>> for TensorDyn {
368    fn from(t: Tensor<u64>) -> Self {
369        Self::U64(t)
370    }
371}
372
373impl From<Tensor<i64>> for TensorDyn {
374    fn from(t: Tensor<i64>) -> Self {
375        Self::I64(t)
376    }
377}
378
379impl From<Tensor<f16>> for TensorDyn {
380    fn from(t: Tensor<f16>) -> Self {
381        Self::F16(t)
382    }
383}
384
385impl From<Tensor<f32>> for TensorDyn {
386    fn from(t: Tensor<f32>) -> Self {
387        Self::F32(t)
388    }
389}
390
391impl From<Tensor<f64>> for TensorDyn {
392    fn from(t: Tensor<f64>) -> Self {
393        Self::F64(t)
394    }
395}
396
397impl fmt::Debug for TensorDyn {
398    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
399        dispatch!(self, fmt, f)
400    }
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_typed_tensor() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert_eq!(dyn_t.dtype(), DType::U8);
        assert_eq!(dyn_t.shape(), &[10]);
    }

    #[test]
    fn downcast_ref() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert!(dyn_t.as_u8().is_some());
        assert!(dyn_t.as_i8().is_none());
    }

    #[test]
    fn downcast_mut() {
        let mut dyn_t: TensorDyn = Tensor::<f32>::new(&[4], None, None).unwrap().into();
        assert!(dyn_t.as_u8_mut().is_none());
        assert!(dyn_t.as_f32_mut().is_some());
    }

    #[test]
    fn downcast_into() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        let back = dyn_t.into_u8().unwrap();
        assert_eq!(back.shape(), &[10]);
    }

    #[test]
    fn downcast_into_mismatch_returns_original() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        // A mismatched `into_*` must hand back the untouched tensor.
        let original = dyn_t
            .into_f32()
            .err()
            .expect("u8 tensor must not downcast to f32");
        assert_eq!(original.dtype(), DType::U8);
        assert_eq!(original.shape(), &[10]);
    }

    #[test]
    fn new_reports_requested_dtype() {
        let t = TensorDyn::new(&[2, 3], DType::F32, None, None).unwrap();
        assert_eq!(t.dtype(), DType::F32);
        assert_eq!(t.shape(), &[2, 3]);
        assert!(t.as_f32().is_some());
    }

    #[test]
    fn reshape_preserves_element_count() {
        let mut t = TensorDyn::new(&[2, 3], DType::U8, None, None).unwrap();
        t.reshape(&[3, 2]).unwrap();
        assert_eq!(t.shape(), &[3, 2]);
        // 8 elements != 6 elements: must be rejected.
        assert!(t.reshape(&[4, 2]).is_err());
    }

    #[test]
    fn image_accessors() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgba));
        assert_eq!(dyn_t.width(), Some(640));
        assert_eq!(dyn_t.height(), Some(480));
        assert!(!dyn_t.is_multiplane());
    }

    #[test]
    fn image_constructor() {
        let dyn_t = TensorDyn::image(640, 480, PixelFormat::Rgb, DType::U8, None).unwrap();
        assert_eq!(dyn_t.dtype(), DType::U8);
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgb));
        assert_eq!(dyn_t.width(), Some(640));
    }

    #[test]
    fn image_constructor_i8() {
        let dyn_t = TensorDyn::image(640, 480, PixelFormat::Rgb, DType::I8, None).unwrap();
        assert_eq!(dyn_t.dtype(), DType::I8);
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgb));
    }

    #[test]
    fn set_format_packed() {
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        assert_eq!(t.format(), None);
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[test]
    fn set_format_planar() {
        let mut t = TensorDyn::new(&[3, 480, 640], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::PlanarRgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[test]
    fn set_format_rejects_wrong_shape() {
        let mut t = TensorDyn::new(&[480, 640, 4], DType::U8, None, None).unwrap();
        assert!(t.set_format(PixelFormat::Rgb).is_err());
    }

    #[test]
    fn with_format_builder() {
        let t = TensorDyn::new(&[480, 640, 4], DType::U8, None, None)
            .unwrap()
            .with_format(PixelFormat::Rgba)
            .unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn dmabuf_clone_mem_tensor_fails() {
        let t = TensorDyn::new(&[480, 640, 3], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        assert_eq!(t.memory(), TensorMemory::Mem);
        assert!(t.dmabuf_clone().is_err());
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn dmabuf_mem_tensor_fails() {
        let t = TensorDyn::new(&[480, 640, 3], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        assert!(t.dmabuf().is_err());
    }

    #[test]
    fn set_format_semi_planar_nv12() {
        // 720 rows = 480 * 3/2 (NV12: height + height/2 for chroma)
        let mut t = TensorDyn::new(&[720, 640], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        t.set_format(PixelFormat::Nv12).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Nv12));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[test]
    fn set_format_semi_planar_nv16() {
        // 960 rows = 480 * 2 (NV16: height + height for chroma)
        let mut t = TensorDyn::new(&[960, 640], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        t.set_format(PixelFormat::Nv16).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Nv16));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[test]
    fn with_format_rejects_wrong_shape() {
        let result = TensorDyn::new(&[480, 640, 4], DType::U8, None, None)
            .unwrap()
            .with_format(PixelFormat::Rgb);
        assert!(result.is_err());
    }

    #[test]
    fn set_format_preserved_after_rejection() {
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));

        // Rgba requires 4 channels, should fail on a 3-channel tensor
        assert!(t.set_format(PixelFormat::Rgba).is_err());

        // Original format should be preserved
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
    }

    #[test]
    fn set_format_idempotent() {
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }
}