Skip to main content

edgefirst_tensor/
tensor_dyn.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{DType, PixelFormat, Tensor, TensorMemory, TensorTrait};
5use half::f16;
6use std::fmt;
7
/// Type-erased tensor. Wraps a `Tensor<T>` with runtime element type.
///
/// Each variant holds a typed [`Tensor`] for one supported element type; which
/// variant is populated is reported at runtime by [`TensorDyn::dtype`]. Typed
/// tensors are wrapped via the `From<Tensor<T>>` impls and recovered with the
/// `as_*` / `as_*_mut` / `into_*` downcasting methods.
///
/// The enum is `#[non_exhaustive]`: code outside this crate must include a
/// wildcard arm when matching on it, so variants can be added without a
/// breaking change.
#[non_exhaustive]
pub enum TensorDyn {
    /// Tensor of unsigned 8-bit integers (`u8`); `dtype()` is `DType::U8`.
    U8(Tensor<u8>),
    /// Tensor of signed 8-bit integers (`i8`); `dtype()` is `DType::I8`.
    I8(Tensor<i8>),
    /// Tensor of unsigned 16-bit integers (`u16`); `dtype()` is `DType::U16`.
    U16(Tensor<u16>),
    /// Tensor of signed 16-bit integers (`i16`); `dtype()` is `DType::I16`.
    I16(Tensor<i16>),
    /// Tensor of unsigned 32-bit integers (`u32`); `dtype()` is `DType::U32`.
    U32(Tensor<u32>),
    /// Tensor of signed 32-bit integers (`i32`); `dtype()` is `DType::I32`.
    I32(Tensor<i32>),
    /// Tensor of unsigned 64-bit integers (`u64`); `dtype()` is `DType::U64`.
    U64(Tensor<u64>),
    /// Tensor of signed 64-bit integers (`i64`); `dtype()` is `DType::I64`.
    I64(Tensor<i64>),
    /// Tensor of 16-bit floats (`half::f16`); `dtype()` is `DType::F16`.
    F16(Tensor<f16>),
    /// Tensor of 32-bit floats (`f32`); `dtype()` is `DType::F32`.
    F32(Tensor<f32>),
    /// Tensor of 64-bit floats (`f64`); `dtype()` is `DType::F64`.
    F64(Tensor<f64>),
}
34
/// Forward a method call to the typed tensor held by any `TensorDyn` variant.
///
/// `dispatch!(target, method, args...)` expands to a `match` on `target` in
/// which every arm binds the inner `Tensor<T>` and evaluates
/// `inner.method(args...)`. Each arm sees a different concrete `Tensor<T>`, so
/// the method must exist on all of them and every arm must yield the same type.
///
/// The `@arms` rule is an internal helper: the public rule hands it the packed
/// argument list plus the full set of variant names, and it stamps out one
/// match arm per variant.
macro_rules! dispatch {
    (@arms $target:expr, $method:ident, $args:tt; $($variant:ident)*) => {
        match $target {
            $(TensorDyn::$variant(inner) => inner.$method $args,)*
        }
    };
    ($target:expr, $method:ident $(, $arg:expr)*) => {
        dispatch!(
            @arms $target, $method, ($($arg),*);
            U8 I8 U16 I16 U32 I32 U64 I64 F16 F32 F64
        )
    };
}
53
/// Emit the three typed accessors for a single `TensorDyn` variant.
///
/// Parameters, in order: the variant name, its element type, and the names to
/// give the shared-borrow, mutable-borrow and consuming accessors.
macro_rules! downcast_methods {
    ($variant:ident, $ty:ty, $as_name:ident, $as_mut_name:ident, $into_name:ident) => {
        /// Borrows the wrapped tensor when this value holds that element type;
        /// yields `None` for any other variant.
        pub fn $as_name(&self) -> Option<&Tensor<$ty>> {
            if let Self::$variant(tensor) = self {
                Some(tensor)
            } else {
                None
            }
        }

        /// Mutably borrows the wrapped tensor when this value holds that element
        /// type; yields `None` for any other variant.
        pub fn $as_mut_name(&mut self) -> Option<&mut Tensor<$ty>> {
            if let Self::$variant(tensor) = self {
                Some(tensor)
            } else {
                None
            }
        }

        /// Consumes `self` and returns the wrapped tensor when the element type
        /// matches; on a mismatch the untouched value is handed back in `Err`.
        pub fn $into_name(self) -> Result<Tensor<$ty>, Self> {
            match self {
                Self::$variant(tensor) => Ok(tensor),
                mismatched => Err(mismatched),
            }
        }
    };
}
82
83impl TensorDyn {
84    /// Return the runtime element type discriminant.
85    pub fn dtype(&self) -> DType {
86        match self {
87            Self::U8(_) => DType::U8,
88            Self::I8(_) => DType::I8,
89            Self::U16(_) => DType::U16,
90            Self::I16(_) => DType::I16,
91            Self::U32(_) => DType::U32,
92            Self::I32(_) => DType::I32,
93            Self::U64(_) => DType::U64,
94            Self::I64(_) => DType::I64,
95            Self::F16(_) => DType::F16,
96            Self::F32(_) => DType::F32,
97            Self::F64(_) => DType::F64,
98        }
99    }
100
101    /// Return the tensor shape.
102    pub fn shape(&self) -> &[usize] {
103        dispatch!(self, shape)
104    }
105
106    /// Return the tensor name.
107    pub fn name(&self) -> String {
108        dispatch!(self, name)
109    }
110
111    /// Return the pixel format (None if not an image tensor).
112    pub fn format(&self) -> Option<PixelFormat> {
113        dispatch!(self, format)
114    }
115
116    /// Return the image width (None if not an image tensor).
117    pub fn width(&self) -> Option<usize> {
118        dispatch!(self, width)
119    }
120
121    /// Return the image height (None if not an image tensor).
122    pub fn height(&self) -> Option<usize> {
123        dispatch!(self, height)
124    }
125
126    /// Return the total size of this tensor in bytes.
127    pub fn size(&self) -> usize {
128        dispatch!(self, size)
129    }
130
131    /// Return the memory allocation type.
132    pub fn memory(&self) -> TensorMemory {
133        dispatch!(self, memory)
134    }
135
136    /// Reshape this tensor. Total element count must remain the same.
137    pub fn reshape(&mut self, shape: &[usize]) -> crate::Result<()> {
138        dispatch!(self, reshape, shape)
139    }
140
141    /// Clone the file descriptor associated with this tensor.
142    #[cfg(unix)]
143    pub fn clone_fd(&self) -> crate::Result<std::os::fd::OwnedFd> {
144        dispatch!(self, clone_fd)
145    }
146
147    /// Return `true` if this tensor uses separate plane allocations.
148    pub fn is_multiplane(&self) -> bool {
149        dispatch!(self, is_multiplane)
150    }
151
152    // --- Downcasting ---
153
154    downcast_methods!(U8, u8, as_u8, as_u8_mut, into_u8);
155    downcast_methods!(I8, i8, as_i8, as_i8_mut, into_i8);
156    downcast_methods!(U16, u16, as_u16, as_u16_mut, into_u16);
157    downcast_methods!(I16, i16, as_i16, as_i16_mut, into_i16);
158    downcast_methods!(U32, u32, as_u32, as_u32_mut, into_u32);
159    downcast_methods!(I32, i32, as_i32, as_i32_mut, into_i32);
160    downcast_methods!(U64, u64, as_u64, as_u64_mut, into_u64);
161    downcast_methods!(I64, i64, as_i64, as_i64_mut, into_i64);
162    downcast_methods!(F16, f16, as_f16, as_f16_mut, into_f16);
163    downcast_methods!(F32, f32, as_f32, as_f32_mut, into_f32);
164    downcast_methods!(F64, f64, as_f64, as_f64_mut, into_f64);
165
166    /// Create a type-erased tensor with the given shape and element type.
167    pub fn new(
168        shape: &[usize],
169        dtype: DType,
170        memory: Option<TensorMemory>,
171        name: Option<&str>,
172    ) -> crate::Result<Self> {
173        match dtype {
174            DType::U8 => Tensor::<u8>::new(shape, memory, name).map(Self::U8),
175            DType::I8 => Tensor::<i8>::new(shape, memory, name).map(Self::I8),
176            DType::U16 => Tensor::<u16>::new(shape, memory, name).map(Self::U16),
177            DType::I16 => Tensor::<i16>::new(shape, memory, name).map(Self::I16),
178            DType::U32 => Tensor::<u32>::new(shape, memory, name).map(Self::U32),
179            DType::I32 => Tensor::<i32>::new(shape, memory, name).map(Self::I32),
180            DType::U64 => Tensor::<u64>::new(shape, memory, name).map(Self::U64),
181            DType::I64 => Tensor::<i64>::new(shape, memory, name).map(Self::I64),
182            DType::F16 => Tensor::<f16>::new(shape, memory, name).map(Self::F16),
183            DType::F32 => Tensor::<f32>::new(shape, memory, name).map(Self::F32),
184            DType::F64 => Tensor::<f64>::new(shape, memory, name).map(Self::F64),
185        }
186    }
187
188    /// Create a type-erased tensor from a file descriptor.
189    #[cfg(unix)]
190    pub fn from_fd(
191        fd: std::os::fd::OwnedFd,
192        shape: &[usize],
193        dtype: DType,
194        name: Option<&str>,
195    ) -> crate::Result<Self> {
196        match dtype {
197            DType::U8 => Tensor::<u8>::from_fd(fd, shape, name).map(Self::U8),
198            DType::I8 => Tensor::<i8>::from_fd(fd, shape, name).map(Self::I8),
199            DType::U16 => Tensor::<u16>::from_fd(fd, shape, name).map(Self::U16),
200            DType::I16 => Tensor::<i16>::from_fd(fd, shape, name).map(Self::I16),
201            DType::U32 => Tensor::<u32>::from_fd(fd, shape, name).map(Self::U32),
202            DType::I32 => Tensor::<i32>::from_fd(fd, shape, name).map(Self::I32),
203            DType::U64 => Tensor::<u64>::from_fd(fd, shape, name).map(Self::U64),
204            DType::I64 => Tensor::<i64>::from_fd(fd, shape, name).map(Self::I64),
205            DType::F16 => Tensor::<f16>::from_fd(fd, shape, name).map(Self::F16),
206            DType::F32 => Tensor::<f32>::from_fd(fd, shape, name).map(Self::F32),
207            DType::F64 => Tensor::<f64>::from_fd(fd, shape, name).map(Self::F64),
208        }
209    }
210
211    /// Create a type-erased image tensor.
212    ///
213    /// # Arguments
214    ///
215    /// * `width` - Image width in pixels
216    /// * `height` - Image height in pixels
217    /// * `format` - Pixel format
218    /// * `dtype` - Element type discriminant
219    /// * `memory` - Optional memory backend (None selects the best available)
220    ///
221    /// # Returns
222    ///
223    /// A new `TensorDyn` wrapping an image tensor of the requested element type.
224    ///
225    /// # Errors
226    ///
227    /// Returns an error if the underlying `Tensor::image` call fails.
228    pub fn image(
229        width: usize,
230        height: usize,
231        format: PixelFormat,
232        dtype: DType,
233        memory: Option<TensorMemory>,
234    ) -> crate::Result<Self> {
235        match dtype {
236            DType::U8 => Tensor::<u8>::image(width, height, format, memory).map(Self::U8),
237            DType::I8 => Tensor::<i8>::image(width, height, format, memory).map(Self::I8),
238            DType::U16 => Tensor::<u16>::image(width, height, format, memory).map(Self::U16),
239            DType::I16 => Tensor::<i16>::image(width, height, format, memory).map(Self::I16),
240            DType::U32 => Tensor::<u32>::image(width, height, format, memory).map(Self::U32),
241            DType::I32 => Tensor::<i32>::image(width, height, format, memory).map(Self::I32),
242            DType::U64 => Tensor::<u64>::image(width, height, format, memory).map(Self::U64),
243            DType::I64 => Tensor::<i64>::image(width, height, format, memory).map(Self::I64),
244            DType::F16 => Tensor::<f16>::image(width, height, format, memory).map(Self::F16),
245            DType::F32 => Tensor::<f32>::image(width, height, format, memory).map(Self::F32),
246            DType::F64 => Tensor::<f64>::image(width, height, format, memory).map(Self::F64),
247        }
248    }
249}
250
251// --- From impls ---
252
253impl From<Tensor<u8>> for TensorDyn {
254    fn from(t: Tensor<u8>) -> Self {
255        Self::U8(t)
256    }
257}
258
259impl From<Tensor<i8>> for TensorDyn {
260    fn from(t: Tensor<i8>) -> Self {
261        Self::I8(t)
262    }
263}
264
265impl From<Tensor<u16>> for TensorDyn {
266    fn from(t: Tensor<u16>) -> Self {
267        Self::U16(t)
268    }
269}
270
271impl From<Tensor<i16>> for TensorDyn {
272    fn from(t: Tensor<i16>) -> Self {
273        Self::I16(t)
274    }
275}
276
277impl From<Tensor<u32>> for TensorDyn {
278    fn from(t: Tensor<u32>) -> Self {
279        Self::U32(t)
280    }
281}
282
283impl From<Tensor<i32>> for TensorDyn {
284    fn from(t: Tensor<i32>) -> Self {
285        Self::I32(t)
286    }
287}
288
289impl From<Tensor<u64>> for TensorDyn {
290    fn from(t: Tensor<u64>) -> Self {
291        Self::U64(t)
292    }
293}
294
295impl From<Tensor<i64>> for TensorDyn {
296    fn from(t: Tensor<i64>) -> Self {
297        Self::I64(t)
298    }
299}
300
301impl From<Tensor<f16>> for TensorDyn {
302    fn from(t: Tensor<f16>) -> Self {
303        Self::F16(t)
304    }
305}
306
307impl From<Tensor<f32>> for TensorDyn {
308    fn from(t: Tensor<f32>) -> Self {
309        Self::F32(t)
310    }
311}
312
313impl From<Tensor<f64>> for TensorDyn {
314    fn from(t: Tensor<f64>) -> Self {
315        Self::F64(t)
316    }
317}
318
319impl fmt::Debug for TensorDyn {
320    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
321        dispatch!(self, fmt, f)
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_typed_tensor() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert_eq!(dyn_t.dtype(), DType::U8);
        assert_eq!(dyn_t.shape(), &[10]);
    }

    #[test]
    fn downcast_ref() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert!(dyn_t.as_u8().is_some());
        assert!(dyn_t.as_i8().is_none());
    }

    #[test]
    fn downcast_mut() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let mut dyn_t: TensorDyn = t.into();
        assert!(dyn_t.as_u8_mut().is_some());
        assert!(dyn_t.as_f32_mut().is_none());
    }

    #[test]
    fn downcast_into() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        let back = dyn_t.into_u8().unwrap();
        assert_eq!(back.shape(), &[10]);
    }

    #[test]
    fn downcast_into_mismatch_returns_original() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        // A mismatched `into_*` must hand the original value back untouched.
        let back = match dyn_t.into_f32() {
            Ok(_) => panic!("u8 tensor must not downcast to f32"),
            Err(original) => original,
        };
        assert_eq!(back.dtype(), DType::U8);
        assert_eq!(back.shape(), &[10]);
        assert!(back.into_u8().is_ok());
    }

    #[test]
    fn new_constructor_f32() {
        let dyn_t = TensorDyn::new(&[2, 3], DType::F32, None, None).unwrap();
        assert_eq!(dyn_t.dtype(), DType::F32);
        assert_eq!(dyn_t.shape(), &[2, 3]);
        assert!(dyn_t.as_f32().is_some());
        assert!(dyn_t.as_u8().is_none());
    }

    #[test]
    fn image_accessors() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgba));
        assert_eq!(dyn_t.width(), Some(640));
        assert_eq!(dyn_t.height(), Some(480));
        assert!(!dyn_t.is_multiplane());
    }

    #[test]
    fn image_constructor() {
        let dyn_t = TensorDyn::image(640, 480, PixelFormat::Rgb, DType::U8, None).unwrap();
        assert_eq!(dyn_t.dtype(), DType::U8);
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgb));
        assert_eq!(dyn_t.width(), Some(640));
    }

    #[test]
    fn image_constructor_i8() {
        let dyn_t = TensorDyn::image(640, 480, PixelFormat::Rgb, DType::I8, None).unwrap();
        assert_eq!(dyn_t.dtype(), DType::I8);
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgb));
    }
}