1use crate::{DType, PixelFormat, Tensor, TensorMemory, TensorTrait};
5use half::f16;
6use std::fmt;
7
/// A tensor whose element type is selected at runtime.
///
/// Each variant wraps a typed [`Tensor<T>`](Tensor) for one supported element
/// type; [`TensorDyn::dtype`] reports which variant is held. The typed tensor
/// can be recovered with the generated `as_*` / `as_*_mut` / `into_*` methods,
/// and a typed tensor converts into this enum via `From<Tensor<T>>`.
///
/// Marked `#[non_exhaustive]`, so `match` expressions outside this crate must
/// include a wildcard arm.
#[non_exhaustive]
pub enum TensorDyn {
    /// Unsigned 8-bit elements ([`DType::U8`]).
    U8(Tensor<u8>),
    /// Signed 8-bit elements ([`DType::I8`]).
    I8(Tensor<i8>),
    /// Unsigned 16-bit elements ([`DType::U16`]).
    U16(Tensor<u16>),
    /// Signed 16-bit elements ([`DType::I16`]).
    I16(Tensor<i16>),
    /// Unsigned 32-bit elements ([`DType::U32`]).
    U32(Tensor<u32>),
    /// Signed 32-bit elements ([`DType::I32`]).
    I32(Tensor<i32>),
    /// Unsigned 64-bit elements ([`DType::U64`]).
    U64(Tensor<u64>),
    /// Signed 64-bit elements ([`DType::I64`]).
    I64(Tensor<i64>),
    /// 16-bit floating-point elements (`half::f16`, [`DType::F16`]).
    F16(Tensor<f16>),
    /// 32-bit floating-point elements ([`DType::F32`]).
    F32(Tensor<f32>),
    /// 64-bit floating-point elements ([`DType::F64`]).
    F64(Tensor<f64>),
}
34
// Forwards `$op($param, ...)` to the typed `Tensor<T>` held by whichever
// `TensorDyn` variant `$target` is, so each runtime-typed accessor does not
// need its own hand-written 11-arm match. `$target` is matched as-is; when it
// is `self` (`&TensorDyn` / `&mut TensorDyn`), default binding modes bind the
// inner tensor by reference.
macro_rules! dispatch {
    ($target:expr, $op:ident $(, $param:expr)*) => {
        match $target {
            TensorDyn::U8(inner) => inner.$op($($param),*),
            TensorDyn::I8(inner) => inner.$op($($param),*),
            TensorDyn::U16(inner) => inner.$op($($param),*),
            TensorDyn::I16(inner) => inner.$op($($param),*),
            TensorDyn::U32(inner) => inner.$op($($param),*),
            TensorDyn::I32(inner) => inner.$op($($param),*),
            TensorDyn::U64(inner) => inner.$op($($param),*),
            TensorDyn::I64(inner) => inner.$op($($param),*),
            TensorDyn::F16(inner) => inner.$op($($param),*),
            TensorDyn::F32(inner) => inner.$op($($param),*),
            TensorDyn::F64(inner) => inner.$op($($param),*),
        }
    };
}
53
// Generates the three typed downcasts for one `TensorDyn` variant, for use
// inside `impl TensorDyn`:
//   * `$as_name`     - shared borrow of the inner tensor, or `None`;
//   * `$as_mut_name` - mutable borrow of the inner tensor, or `None`;
//   * `$into_name`   - by-value unwrap, handing `self` back on mismatch.
// The `#[doc]` strings are built per element type so every generated public
// method carries its own rustdoc.
macro_rules! downcast_methods {
    ($variant:ident, $ty:ty, $as_name:ident, $as_mut_name:ident, $into_name:ident) => {
        #[doc = concat!(
            "Borrows the inner `Tensor<", stringify!($ty), ">` if this is the `",
            stringify!($variant), "` variant; returns `None` otherwise."
        )]
        #[must_use]
        pub fn $as_name(&self) -> Option<&Tensor<$ty>> {
            match self {
                Self::$variant(t) => Some(t),
                _ => None,
            }
        }

        #[doc = concat!(
            "Mutably borrows the inner `Tensor<", stringify!($ty), ">` if this is the `",
            stringify!($variant), "` variant; returns `None` otherwise."
        )]
        #[must_use]
        pub fn $as_mut_name(&mut self) -> Option<&mut Tensor<$ty>> {
            match self {
                Self::$variant(t) => Some(t),
                _ => None,
            }
        }

        #[doc = concat!(
            "Unwraps the inner `Tensor<", stringify!($ty), ">` if this is the `",
            stringify!($variant), "` variant.\n\n",
            "# Errors\n\n",
            "Returns `Err(self)` unchanged when a different variant is held, ",
            "so the caller can attempt another downcast."
        )]
        pub fn $into_name(self) -> Result<Tensor<$ty>, Self> {
            match self {
                Self::$variant(t) => Ok(t),
                other => Err(other),
            }
        }
    };
}
82
83impl TensorDyn {
84 pub fn dtype(&self) -> DType {
86 match self {
87 Self::U8(_) => DType::U8,
88 Self::I8(_) => DType::I8,
89 Self::U16(_) => DType::U16,
90 Self::I16(_) => DType::I16,
91 Self::U32(_) => DType::U32,
92 Self::I32(_) => DType::I32,
93 Self::U64(_) => DType::U64,
94 Self::I64(_) => DType::I64,
95 Self::F16(_) => DType::F16,
96 Self::F32(_) => DType::F32,
97 Self::F64(_) => DType::F64,
98 }
99 }
100
101 pub fn shape(&self) -> &[usize] {
103 dispatch!(self, shape)
104 }
105
106 pub fn name(&self) -> String {
108 dispatch!(self, name)
109 }
110
111 pub fn format(&self) -> Option<PixelFormat> {
113 dispatch!(self, format)
114 }
115
116 pub fn width(&self) -> Option<usize> {
118 dispatch!(self, width)
119 }
120
121 pub fn height(&self) -> Option<usize> {
123 dispatch!(self, height)
124 }
125
126 pub fn size(&self) -> usize {
128 dispatch!(self, size)
129 }
130
131 pub fn memory(&self) -> TensorMemory {
133 dispatch!(self, memory)
134 }
135
136 pub fn reshape(&mut self, shape: &[usize]) -> crate::Result<()> {
138 dispatch!(self, reshape, shape)
139 }
140
141 #[cfg(unix)]
143 pub fn clone_fd(&self) -> crate::Result<std::os::fd::OwnedFd> {
144 dispatch!(self, clone_fd)
145 }
146
147 pub fn is_multiplane(&self) -> bool {
149 dispatch!(self, is_multiplane)
150 }
151
152 downcast_methods!(U8, u8, as_u8, as_u8_mut, into_u8);
155 downcast_methods!(I8, i8, as_i8, as_i8_mut, into_i8);
156 downcast_methods!(U16, u16, as_u16, as_u16_mut, into_u16);
157 downcast_methods!(I16, i16, as_i16, as_i16_mut, into_i16);
158 downcast_methods!(U32, u32, as_u32, as_u32_mut, into_u32);
159 downcast_methods!(I32, i32, as_i32, as_i32_mut, into_i32);
160 downcast_methods!(U64, u64, as_u64, as_u64_mut, into_u64);
161 downcast_methods!(I64, i64, as_i64, as_i64_mut, into_i64);
162 downcast_methods!(F16, f16, as_f16, as_f16_mut, into_f16);
163 downcast_methods!(F32, f32, as_f32, as_f32_mut, into_f32);
164 downcast_methods!(F64, f64, as_f64, as_f64_mut, into_f64);
165
166 pub fn new(
168 shape: &[usize],
169 dtype: DType,
170 memory: Option<TensorMemory>,
171 name: Option<&str>,
172 ) -> crate::Result<Self> {
173 match dtype {
174 DType::U8 => Tensor::<u8>::new(shape, memory, name).map(Self::U8),
175 DType::I8 => Tensor::<i8>::new(shape, memory, name).map(Self::I8),
176 DType::U16 => Tensor::<u16>::new(shape, memory, name).map(Self::U16),
177 DType::I16 => Tensor::<i16>::new(shape, memory, name).map(Self::I16),
178 DType::U32 => Tensor::<u32>::new(shape, memory, name).map(Self::U32),
179 DType::I32 => Tensor::<i32>::new(shape, memory, name).map(Self::I32),
180 DType::U64 => Tensor::<u64>::new(shape, memory, name).map(Self::U64),
181 DType::I64 => Tensor::<i64>::new(shape, memory, name).map(Self::I64),
182 DType::F16 => Tensor::<f16>::new(shape, memory, name).map(Self::F16),
183 DType::F32 => Tensor::<f32>::new(shape, memory, name).map(Self::F32),
184 DType::F64 => Tensor::<f64>::new(shape, memory, name).map(Self::F64),
185 }
186 }
187
188 #[cfg(unix)]
190 pub fn from_fd(
191 fd: std::os::fd::OwnedFd,
192 shape: &[usize],
193 dtype: DType,
194 name: Option<&str>,
195 ) -> crate::Result<Self> {
196 match dtype {
197 DType::U8 => Tensor::<u8>::from_fd(fd, shape, name).map(Self::U8),
198 DType::I8 => Tensor::<i8>::from_fd(fd, shape, name).map(Self::I8),
199 DType::U16 => Tensor::<u16>::from_fd(fd, shape, name).map(Self::U16),
200 DType::I16 => Tensor::<i16>::from_fd(fd, shape, name).map(Self::I16),
201 DType::U32 => Tensor::<u32>::from_fd(fd, shape, name).map(Self::U32),
202 DType::I32 => Tensor::<i32>::from_fd(fd, shape, name).map(Self::I32),
203 DType::U64 => Tensor::<u64>::from_fd(fd, shape, name).map(Self::U64),
204 DType::I64 => Tensor::<i64>::from_fd(fd, shape, name).map(Self::I64),
205 DType::F16 => Tensor::<f16>::from_fd(fd, shape, name).map(Self::F16),
206 DType::F32 => Tensor::<f32>::from_fd(fd, shape, name).map(Self::F32),
207 DType::F64 => Tensor::<f64>::from_fd(fd, shape, name).map(Self::F64),
208 }
209 }
210
211 pub fn image(
229 width: usize,
230 height: usize,
231 format: PixelFormat,
232 dtype: DType,
233 memory: Option<TensorMemory>,
234 ) -> crate::Result<Self> {
235 match dtype {
236 DType::U8 => Tensor::<u8>::image(width, height, format, memory).map(Self::U8),
237 DType::I8 => Tensor::<i8>::image(width, height, format, memory).map(Self::I8),
238 DType::U16 => Tensor::<u16>::image(width, height, format, memory).map(Self::U16),
239 DType::I16 => Tensor::<i16>::image(width, height, format, memory).map(Self::I16),
240 DType::U32 => Tensor::<u32>::image(width, height, format, memory).map(Self::U32),
241 DType::I32 => Tensor::<i32>::image(width, height, format, memory).map(Self::I32),
242 DType::U64 => Tensor::<u64>::image(width, height, format, memory).map(Self::U64),
243 DType::I64 => Tensor::<i64>::image(width, height, format, memory).map(Self::I64),
244 DType::F16 => Tensor::<f16>::image(width, height, format, memory).map(Self::F16),
245 DType::F32 => Tensor::<f32>::image(width, height, format, memory).map(Self::F32),
246 DType::F64 => Tensor::<f64>::image(width, height, format, memory).map(Self::F64),
247 }
248 }
249}
250
251impl From<Tensor<u8>> for TensorDyn {
254 fn from(t: Tensor<u8>) -> Self {
255 Self::U8(t)
256 }
257}
258
259impl From<Tensor<i8>> for TensorDyn {
260 fn from(t: Tensor<i8>) -> Self {
261 Self::I8(t)
262 }
263}
264
265impl From<Tensor<u16>> for TensorDyn {
266 fn from(t: Tensor<u16>) -> Self {
267 Self::U16(t)
268 }
269}
270
271impl From<Tensor<i16>> for TensorDyn {
272 fn from(t: Tensor<i16>) -> Self {
273 Self::I16(t)
274 }
275}
276
277impl From<Tensor<u32>> for TensorDyn {
278 fn from(t: Tensor<u32>) -> Self {
279 Self::U32(t)
280 }
281}
282
283impl From<Tensor<i32>> for TensorDyn {
284 fn from(t: Tensor<i32>) -> Self {
285 Self::I32(t)
286 }
287}
288
289impl From<Tensor<u64>> for TensorDyn {
290 fn from(t: Tensor<u64>) -> Self {
291 Self::U64(t)
292 }
293}
294
295impl From<Tensor<i64>> for TensorDyn {
296 fn from(t: Tensor<i64>) -> Self {
297 Self::I64(t)
298 }
299}
300
301impl From<Tensor<f16>> for TensorDyn {
302 fn from(t: Tensor<f16>) -> Self {
303 Self::F16(t)
304 }
305}
306
307impl From<Tensor<f32>> for TensorDyn {
308 fn from(t: Tensor<f32>) -> Self {
309 Self::F32(t)
310 }
311}
312
313impl From<Tensor<f64>> for TensorDyn {
314 fn from(t: Tensor<f64>) -> Self {
315 Self::F64(t)
316 }
317}
318
319impl fmt::Debug for TensorDyn {
320 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
321 dispatch!(self, fmt, f)
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_typed_tensor() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert_eq!(dyn_t.dtype(), DType::U8);
        assert_eq!(dyn_t.shape(), &[10]);
    }

    #[test]
    fn downcast_ref() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert!(dyn_t.as_u8().is_some());
        assert!(dyn_t.as_i8().is_none());
    }

    // The mutable downcast must follow the same variant check as `as_*`.
    #[test]
    fn downcast_mut() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let mut dyn_t: TensorDyn = t.into();
        assert!(dyn_t.as_u8_mut().is_some());
        assert!(dyn_t.as_f32_mut().is_none());
    }

    #[test]
    fn downcast_into() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        let back = dyn_t.into_u8().unwrap();
        assert_eq!(back.shape(), &[10]);
    }

    // On a type mismatch, `into_*` hands the original value back untouched.
    #[test]
    fn downcast_into_mismatch_returns_original() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        let Err(back) = dyn_t.into_i8() else {
            panic!("into_i8 must fail on a U8 tensor");
        };
        assert_eq!(back.dtype(), DType::U8);
        assert_eq!(back.shape(), &[10]);
    }

    // `TensorDyn::new` must build the variant that matches the requested dtype.
    #[test]
    fn new_honors_dtype() {
        let dyn_t = TensorDyn::new(&[2, 3], DType::F32, None, None).unwrap();
        assert_eq!(dyn_t.dtype(), DType::F32);
        assert_eq!(dyn_t.shape(), &[2, 3]);
        assert!(dyn_t.as_f32().is_some());

        let dyn_t = TensorDyn::new(&[4], DType::I16, None, None).unwrap();
        assert_eq!(dyn_t.dtype(), DType::I16);
        assert!(dyn_t.as_i16().is_some());
        assert!(dyn_t.as_u16().is_none());
    }

    #[test]
    fn image_accessors() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgba));
        assert_eq!(dyn_t.width(), Some(640));
        assert_eq!(dyn_t.height(), Some(480));
        assert!(!dyn_t.is_multiplane());
    }

    #[test]
    fn image_constructor() {
        let dyn_t = TensorDyn::image(640, 480, PixelFormat::Rgb, DType::U8, None).unwrap();
        assert_eq!(dyn_t.dtype(), DType::U8);
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgb));
        assert_eq!(dyn_t.width(), Some(640));
    }

    #[test]
    fn image_constructor_i8() {
        let dyn_t = TensorDyn::image(640, 480, PixelFormat::Rgb, DType::I8, None).unwrap();
        assert_eq!(dyn_t.dtype(), DType::I8);
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgb));
    }
}