1use crate::{DType, PixelFormat, Tensor, TensorMemory, TensorTrait};
5use half::f16;
6use std::fmt;
7
/// A [`Tensor`] whose element type is selected at runtime.
///
/// Each variant wraps a typed `Tensor<T>`; [`TensorDyn::dtype`] maps the
/// active variant to its [`DType`]. Shape, format and memory queries are
/// forwarded to the wrapped tensor, and the `as_*` / `as_*_mut` / `into_*`
/// methods recover the typed tensor.
///
/// Because of `#[non_exhaustive]`, code outside this crate must include a
/// wildcard arm when matching on the variants.
#[non_exhaustive]
pub enum TensorDyn {
    /// 8-bit unsigned integer elements.
    U8(Tensor<u8>),
    /// 8-bit signed integer elements.
    I8(Tensor<i8>),
    /// 16-bit unsigned integer elements.
    U16(Tensor<u16>),
    /// 16-bit signed integer elements.
    I16(Tensor<i16>),
    /// 32-bit unsigned integer elements.
    U32(Tensor<u32>),
    /// 32-bit signed integer elements.
    I32(Tensor<i32>),
    /// 64-bit unsigned integer elements.
    U64(Tensor<u64>),
    /// 64-bit signed integer elements.
    I64(Tensor<i64>),
    /// 16-bit half-precision float elements (`half::f16`).
    F16(Tensor<f16>),
    /// 32-bit float elements.
    F32(Tensor<f32>),
    /// 64-bit float elements.
    F64(Tensor<f64>),
}
34
/// Forwards a method call to the typed tensor held by a `TensorDyn`.
///
/// `dispatch!(value, method, arg1, arg2, ...)` expands to a `match` on
/// `value` with one arm per variant, each calling `inner.method(args...)` on
/// the payload. Only the arm for the active variant runs, so each argument
/// expression is evaluated once. Every arm must yield the same type.
macro_rules! dispatch {
    // Internal rule: emit one arm per listed variant. The argument list is
    // carried as a single parenthesised token tree so it can be repeated
    // inside the per-variant expansion without nesting repetitions.
    (@arms $target:expr, $method:ident, $args:tt, [$($variant:ident)+]) => {
        match $target {
            $(TensorDyn::$variant(inner) => inner.$method $args,)+
        }
    };
    // Public entry point.
    ($self:expr, $method:ident $(, $arg:expr)*) => {
        dispatch!(
            @arms $self, $method, ($($arg),*),
            [U8 I8 U16 I16 U32 I32 U64 I64 F16 F32 F64]
        )
    };
}
53
/// Generates the three typed accessors for one `TensorDyn` variant.
///
/// Arguments: the variant name, its element type, then the names for
/// - `$as_name`: shared borrow of the payload, `None` on a type mismatch;
/// - `$as_mut_name`: mutable borrow of the payload, `None` on a mismatch;
/// - `$into_name`: unwraps the payload by value, handing the untouched
///   `TensorDyn` back as `Err` on a mismatch.
macro_rules! downcast_methods {
    ($variant:ident, $ty:ty, $as_name:ident, $as_mut_name:ident, $into_name:ident) => {
        pub fn $as_name(&self) -> Option<&Tensor<$ty>> {
            if let Self::$variant(inner) = self {
                Some(inner)
            } else {
                None
            }
        }

        pub fn $as_mut_name(&mut self) -> Option<&mut Tensor<$ty>> {
            if let Self::$variant(inner) = self {
                Some(inner)
            } else {
                None
            }
        }

        pub fn $into_name(self) -> Result<Tensor<$ty>, Self> {
            // `self` is only moved out of in the matching branch, so the
            // `else` branch can still return it whole.
            if let Self::$variant(inner) = self {
                Ok(inner)
            } else {
                Err(self)
            }
        }
    };
}
82
83impl TensorDyn {
84 pub fn dtype(&self) -> DType {
86 match self {
87 Self::U8(_) => DType::U8,
88 Self::I8(_) => DType::I8,
89 Self::U16(_) => DType::U16,
90 Self::I16(_) => DType::I16,
91 Self::U32(_) => DType::U32,
92 Self::I32(_) => DType::I32,
93 Self::U64(_) => DType::U64,
94 Self::I64(_) => DType::I64,
95 Self::F16(_) => DType::F16,
96 Self::F32(_) => DType::F32,
97 Self::F64(_) => DType::F64,
98 }
99 }
100
101 pub fn shape(&self) -> &[usize] {
103 dispatch!(self, shape)
104 }
105
106 pub fn name(&self) -> String {
108 dispatch!(self, name)
109 }
110
111 pub fn format(&self) -> Option<PixelFormat> {
113 dispatch!(self, format)
114 }
115
116 pub fn width(&self) -> Option<usize> {
118 dispatch!(self, width)
119 }
120
121 pub fn height(&self) -> Option<usize> {
123 dispatch!(self, height)
124 }
125
126 pub fn size(&self) -> usize {
128 dispatch!(self, size)
129 }
130
131 pub fn memory(&self) -> TensorMemory {
133 dispatch!(self, memory)
134 }
135
136 pub fn reshape(&mut self, shape: &[usize]) -> crate::Result<()> {
138 dispatch!(self, reshape, shape)
139 }
140
141 pub fn set_format(&mut self, format: PixelFormat) -> crate::Result<()> {
159 dispatch!(self, set_format, format)
160 }
161
162 pub fn with_format(mut self, format: PixelFormat) -> crate::Result<Self> {
179 self.set_format(format)?;
180 Ok(self)
181 }
182
183 #[cfg(unix)]
185 pub fn clone_fd(&self) -> crate::Result<std::os::fd::OwnedFd> {
186 dispatch!(self, clone_fd)
187 }
188
189 #[cfg(target_os = "linux")]
200 pub fn dmabuf_clone(&self) -> crate::Result<std::os::fd::OwnedFd> {
201 if self.memory() != TensorMemory::Dma {
202 return Err(crate::Error::NotImplemented(format!(
203 "dmabuf_clone requires DMA-backed tensor, got {:?}",
204 self.memory()
205 )));
206 }
207 self.clone_fd()
208 }
209
210 #[cfg(target_os = "linux")]
221 pub fn dmabuf(&self) -> crate::Result<std::os::fd::BorrowedFd<'_>> {
222 dispatch!(self, dmabuf)
223 }
224
225 pub fn is_multiplane(&self) -> bool {
227 dispatch!(self, is_multiplane)
228 }
229
230 downcast_methods!(U8, u8, as_u8, as_u8_mut, into_u8);
233 downcast_methods!(I8, i8, as_i8, as_i8_mut, into_i8);
234 downcast_methods!(U16, u16, as_u16, as_u16_mut, into_u16);
235 downcast_methods!(I16, i16, as_i16, as_i16_mut, into_i16);
236 downcast_methods!(U32, u32, as_u32, as_u32_mut, into_u32);
237 downcast_methods!(I32, i32, as_i32, as_i32_mut, into_i32);
238 downcast_methods!(U64, u64, as_u64, as_u64_mut, into_u64);
239 downcast_methods!(I64, i64, as_i64, as_i64_mut, into_i64);
240 downcast_methods!(F16, f16, as_f16, as_f16_mut, into_f16);
241 downcast_methods!(F32, f32, as_f32, as_f32_mut, into_f32);
242 downcast_methods!(F64, f64, as_f64, as_f64_mut, into_f64);
243
244 pub fn new(
246 shape: &[usize],
247 dtype: DType,
248 memory: Option<TensorMemory>,
249 name: Option<&str>,
250 ) -> crate::Result<Self> {
251 match dtype {
252 DType::U8 => Tensor::<u8>::new(shape, memory, name).map(Self::U8),
253 DType::I8 => Tensor::<i8>::new(shape, memory, name).map(Self::I8),
254 DType::U16 => Tensor::<u16>::new(shape, memory, name).map(Self::U16),
255 DType::I16 => Tensor::<i16>::new(shape, memory, name).map(Self::I16),
256 DType::U32 => Tensor::<u32>::new(shape, memory, name).map(Self::U32),
257 DType::I32 => Tensor::<i32>::new(shape, memory, name).map(Self::I32),
258 DType::U64 => Tensor::<u64>::new(shape, memory, name).map(Self::U64),
259 DType::I64 => Tensor::<i64>::new(shape, memory, name).map(Self::I64),
260 DType::F16 => Tensor::<f16>::new(shape, memory, name).map(Self::F16),
261 DType::F32 => Tensor::<f32>::new(shape, memory, name).map(Self::F32),
262 DType::F64 => Tensor::<f64>::new(shape, memory, name).map(Self::F64),
263 }
264 }
265
266 #[cfg(unix)]
268 pub fn from_fd(
269 fd: std::os::fd::OwnedFd,
270 shape: &[usize],
271 dtype: DType,
272 name: Option<&str>,
273 ) -> crate::Result<Self> {
274 match dtype {
275 DType::U8 => Tensor::<u8>::from_fd(fd, shape, name).map(Self::U8),
276 DType::I8 => Tensor::<i8>::from_fd(fd, shape, name).map(Self::I8),
277 DType::U16 => Tensor::<u16>::from_fd(fd, shape, name).map(Self::U16),
278 DType::I16 => Tensor::<i16>::from_fd(fd, shape, name).map(Self::I16),
279 DType::U32 => Tensor::<u32>::from_fd(fd, shape, name).map(Self::U32),
280 DType::I32 => Tensor::<i32>::from_fd(fd, shape, name).map(Self::I32),
281 DType::U64 => Tensor::<u64>::from_fd(fd, shape, name).map(Self::U64),
282 DType::I64 => Tensor::<i64>::from_fd(fd, shape, name).map(Self::I64),
283 DType::F16 => Tensor::<f16>::from_fd(fd, shape, name).map(Self::F16),
284 DType::F32 => Tensor::<f32>::from_fd(fd, shape, name).map(Self::F32),
285 DType::F64 => Tensor::<f64>::from_fd(fd, shape, name).map(Self::F64),
286 }
287 }
288
289 pub fn image(
307 width: usize,
308 height: usize,
309 format: PixelFormat,
310 dtype: DType,
311 memory: Option<TensorMemory>,
312 ) -> crate::Result<Self> {
313 match dtype {
314 DType::U8 => Tensor::<u8>::image(width, height, format, memory).map(Self::U8),
315 DType::I8 => Tensor::<i8>::image(width, height, format, memory).map(Self::I8),
316 DType::U16 => Tensor::<u16>::image(width, height, format, memory).map(Self::U16),
317 DType::I16 => Tensor::<i16>::image(width, height, format, memory).map(Self::I16),
318 DType::U32 => Tensor::<u32>::image(width, height, format, memory).map(Self::U32),
319 DType::I32 => Tensor::<i32>::image(width, height, format, memory).map(Self::I32),
320 DType::U64 => Tensor::<u64>::image(width, height, format, memory).map(Self::U64),
321 DType::I64 => Tensor::<i64>::image(width, height, format, memory).map(Self::I64),
322 DType::F16 => Tensor::<f16>::image(width, height, format, memory).map(Self::F16),
323 DType::F32 => Tensor::<f32>::image(width, height, format, memory).map(Self::F32),
324 DType::F64 => Tensor::<f64>::image(width, height, format, memory).map(Self::F64),
325 }
326 }
327}
328
329impl From<Tensor<u8>> for TensorDyn {
332 fn from(t: Tensor<u8>) -> Self {
333 Self::U8(t)
334 }
335}
336
337impl From<Tensor<i8>> for TensorDyn {
338 fn from(t: Tensor<i8>) -> Self {
339 Self::I8(t)
340 }
341}
342
343impl From<Tensor<u16>> for TensorDyn {
344 fn from(t: Tensor<u16>) -> Self {
345 Self::U16(t)
346 }
347}
348
349impl From<Tensor<i16>> for TensorDyn {
350 fn from(t: Tensor<i16>) -> Self {
351 Self::I16(t)
352 }
353}
354
355impl From<Tensor<u32>> for TensorDyn {
356 fn from(t: Tensor<u32>) -> Self {
357 Self::U32(t)
358 }
359}
360
361impl From<Tensor<i32>> for TensorDyn {
362 fn from(t: Tensor<i32>) -> Self {
363 Self::I32(t)
364 }
365}
366
367impl From<Tensor<u64>> for TensorDyn {
368 fn from(t: Tensor<u64>) -> Self {
369 Self::U64(t)
370 }
371}
372
373impl From<Tensor<i64>> for TensorDyn {
374 fn from(t: Tensor<i64>) -> Self {
375 Self::I64(t)
376 }
377}
378
379impl From<Tensor<f16>> for TensorDyn {
380 fn from(t: Tensor<f16>) -> Self {
381 Self::F16(t)
382 }
383}
384
385impl From<Tensor<f32>> for TensorDyn {
386 fn from(t: Tensor<f32>) -> Self {
387 Self::F32(t)
388 }
389}
390
391impl From<Tensor<f64>> for TensorDyn {
392 fn from(t: Tensor<f64>) -> Self {
393 Self::F64(t)
394 }
395}
396
397impl fmt::Debug for TensorDyn {
398 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
399 dispatch!(self, fmt, f)
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_typed_tensor() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert_eq!(dyn_t.dtype(), DType::U8);
        assert_eq!(dyn_t.shape(), &[10]);
    }

    #[test]
    fn downcast_ref() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert!(dyn_t.as_u8().is_some());
        assert!(dyn_t.as_i8().is_none());
    }

    #[test]
    fn downcast_into() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        let back = dyn_t.into_u8().unwrap();
        assert_eq!(back.shape(), &[10]);
    }

    #[test]
    fn image_accessors() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgba));
        assert_eq!(dyn_t.width(), Some(640));
        assert_eq!(dyn_t.height(), Some(480));
        assert!(!dyn_t.is_multiplane());
    }

    #[test]
    fn image_constructor() {
        let dyn_t = TensorDyn::image(640, 480, PixelFormat::Rgb, DType::U8, None).unwrap();
        assert_eq!(dyn_t.dtype(), DType::U8);
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgb));
        assert_eq!(dyn_t.width(), Some(640));
        assert_eq!(dyn_t.height(), Some(480));
    }

    #[test]
    fn image_constructor_i8() {
        let dyn_t = TensorDyn::image(640, 480, PixelFormat::Rgb, DType::I8, None).unwrap();
        assert_eq!(dyn_t.dtype(), DType::I8);
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgb));
    }

    #[test]
    fn set_format_packed() {
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        assert_eq!(t.format(), None);
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[test]
    fn set_format_planar() {
        let mut t = TensorDyn::new(&[3, 480, 640], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::PlanarRgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[test]
    fn set_format_rejects_wrong_shape() {
        let mut t = TensorDyn::new(&[480, 640, 4], DType::U8, None, None).unwrap();
        assert!(t.set_format(PixelFormat::Rgb).is_err());
    }

    #[test]
    fn with_format_builder() {
        let t = TensorDyn::new(&[480, 640, 4], DType::U8, None, None)
            .unwrap()
            .with_format(PixelFormat::Rgba)
            .unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn dmabuf_clone_mem_tensor_fails() {
        let t = TensorDyn::new(&[480, 640, 3], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        assert_eq!(t.memory(), TensorMemory::Mem);
        assert!(t.dmabuf_clone().is_err());
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn dmabuf_mem_tensor_fails() {
        let t = TensorDyn::new(&[480, 640, 3], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        assert!(t.dmabuf().is_err());
    }

    #[test]
    fn set_format_semi_planar_nv12() {
        let mut t = TensorDyn::new(&[720, 640], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        t.set_format(PixelFormat::Nv12).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Nv12));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[test]
    fn set_format_semi_planar_nv16() {
        let mut t = TensorDyn::new(&[960, 640], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        t.set_format(PixelFormat::Nv16).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Nv16));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    #[test]
    fn with_format_rejects_wrong_shape() {
        let result = TensorDyn::new(&[480, 640, 4], DType::U8, None, None)
            .unwrap()
            .with_format(PixelFormat::Rgb);
        assert!(result.is_err());
    }

    #[test]
    fn set_format_preserved_after_rejection() {
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));

        assert!(t.set_format(PixelFormat::Rgba).is_err());

        assert_eq!(t.format(), Some(PixelFormat::Rgb));
    }

    #[test]
    fn set_format_idempotent() {
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }

    // The mutable accessor must hand out the live payload, not a copy.
    #[test]
    fn downcast_mut() {
        let mut dyn_t: TensorDyn = Tensor::<u8>::new(&[10], None, None).unwrap().into();
        assert!(dyn_t.as_f32_mut().is_none());
        dyn_t.as_u8_mut().unwrap().reshape(&[2, 5]).unwrap();
        assert_eq!(dyn_t.shape(), &[2, 5]);
    }

    // A mismatched `into_*` must give the original tensor back intact.
    #[test]
    fn downcast_into_wrong_type_returns_original() {
        let dyn_t: TensorDyn = Tensor::<u8>::new(&[10], None, None).unwrap().into();
        let original = match dyn_t.into_i8() {
            Ok(_) => panic!("a U8 tensor must not downcast to i8"),
            Err(original) => original,
        };
        assert_eq!(original.dtype(), DType::U8);
        assert_eq!(original.shape(), &[10]);
    }

    // Every dtype selects the matching variant in the runtime constructor.
    #[test]
    fn new_reports_requested_dtype() {
        macro_rules! check {
            ($dtype:expr) => {{
                let t = TensorDyn::new(&[4], $dtype, Some(TensorMemory::Mem), None).unwrap();
                assert_eq!(t.dtype(), $dtype);
                assert_eq!(t.shape(), &[4]);
            }};
        }
        check!(DType::U8);
        check!(DType::I8);
        check!(DType::U16);
        check!(DType::I16);
        check!(DType::U32);
        check!(DType::I32);
        check!(DType::U64);
        check!(DType::I64);
        check!(DType::F16);
        check!(DType::F32);
        check!(DType::F64);
    }
}