// openinfer_simulator/tensor/value.rs

1use anyhow::{anyhow, Result};
2use serde::{Deserialize, Serialize};
3
4use super::{
5    numel, Bitset, BF16, F16, F8, I1, I2, I4, T1, T2, U1, U2, U4, Tensor, TensorOptions,
6};
7
8/// Element type that can be converted to/from `TensorValue`.
9pub trait TensorElement: Sized + Clone {
10    /// Attempt to extract a typed tensor from a generic value.
11    fn from_value(value: &TensorValue) -> Option<Tensor<Self>>;
12    /// Wrap a typed tensor into a generic value.
13    fn into_value(tensor: Tensor<Self>) -> TensorValue;
14}
15
16impl<T> From<Vec<T>> for Tensor<T> {
17    fn from(value: Vec<T>) -> Self {
18        Tensor::new(value)
19    }
20}
21
22impl TensorElement for f32 {
23    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
24        match value {
25            TensorValue::F32(tensor) => Some(tensor.clone()),
26            _ => None,
27        }
28    }
29
30    fn into_value(tensor: Tensor<Self>) -> TensorValue {
31        TensorValue::F32(tensor)
32    }
33}
34
35impl TensorElement for f64 {
36    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
37        match value {
38            TensorValue::F64(tensor) => Some(tensor.clone()),
39            _ => None,
40        }
41    }
42
43    fn into_value(tensor: Tensor<Self>) -> TensorValue {
44        TensorValue::F64(tensor)
45    }
46}
47
48impl TensorElement for i8 {
49    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
50        match value {
51            TensorValue::I8(tensor) => Some(tensor.clone()),
52            _ => None,
53        }
54    }
55
56    fn into_value(tensor: Tensor<Self>) -> TensorValue {
57        TensorValue::I8(tensor)
58    }
59}
60
61impl TensorElement for i16 {
62    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
63        match value {
64            TensorValue::I16(tensor) => Some(tensor.clone()),
65            _ => None,
66        }
67    }
68
69    fn into_value(tensor: Tensor<Self>) -> TensorValue {
70        TensorValue::I16(tensor)
71    }
72}
73
74impl TensorElement for i32 {
75    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
76        match value {
77            TensorValue::I32(tensor) => Some(tensor.clone()),
78            _ => None,
79        }
80    }
81
82    fn into_value(tensor: Tensor<Self>) -> TensorValue {
83        TensorValue::I32(tensor)
84    }
85}
86
87impl TensorElement for i64 {
88    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
89        match value {
90            TensorValue::I64(tensor) => Some(tensor.clone()),
91            _ => None,
92        }
93    }
94
95    fn into_value(tensor: Tensor<Self>) -> TensorValue {
96        TensorValue::I64(tensor)
97    }
98}
99
100impl TensorElement for u8 {
101    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
102        match value {
103            TensorValue::U8(tensor) => Some(tensor.clone()),
104            _ => None,
105        }
106    }
107
108    fn into_value(tensor: Tensor<Self>) -> TensorValue {
109        TensorValue::U8(tensor)
110    }
111}
112
113impl TensorElement for u16 {
114    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
115        match value {
116            TensorValue::U16(tensor) => Some(tensor.clone()),
117            _ => None,
118        }
119    }
120
121    fn into_value(tensor: Tensor<Self>) -> TensorValue {
122        TensorValue::U16(tensor)
123    }
124}
125
126impl TensorElement for u32 {
127    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
128        match value {
129            TensorValue::U32(tensor) => Some(tensor.clone()),
130            _ => None,
131        }
132    }
133
134    fn into_value(tensor: Tensor<Self>) -> TensorValue {
135        TensorValue::U32(tensor)
136    }
137}
138
139impl TensorElement for u64 {
140    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
141        match value {
142            TensorValue::U64(tensor) => Some(tensor.clone()),
143            _ => None,
144        }
145    }
146
147    fn into_value(tensor: Tensor<Self>) -> TensorValue {
148        TensorValue::U64(tensor)
149    }
150}
151
152impl TensorElement for bool {
153    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
154        match value {
155            TensorValue::Bool(tensor) => Some(tensor.clone()),
156            _ => None,
157        }
158    }
159
160    fn into_value(tensor: Tensor<Self>) -> TensorValue {
161        TensorValue::Bool(tensor)
162    }
163}
164
165impl TensorElement for F16 {
166    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
167        match value {
168            TensorValue::F16(tensor) => Some(tensor.clone()),
169            _ => None,
170        }
171    }
172
173    fn into_value(tensor: Tensor<Self>) -> TensorValue {
174        TensorValue::F16(tensor)
175    }
176}
177
178impl TensorElement for BF16 {
179    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
180        match value {
181            TensorValue::BF16(tensor) => Some(tensor.clone()),
182            _ => None,
183        }
184    }
185
186    fn into_value(tensor: Tensor<Self>) -> TensorValue {
187        TensorValue::BF16(tensor)
188    }
189}
190
191impl TensorElement for F8 {
192    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
193        match value {
194            TensorValue::F8(tensor) => Some(tensor.clone()),
195            _ => None,
196        }
197    }
198
199    fn into_value(tensor: Tensor<Self>) -> TensorValue {
200        TensorValue::F8(tensor)
201    }
202}
203
204impl TensorElement for I4 {
205    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
206        match value {
207            TensorValue::I4(tensor) => Some(tensor.clone()),
208            _ => None,
209        }
210    }
211
212    fn into_value(tensor: Tensor<Self>) -> TensorValue {
213        TensorValue::I4(tensor)
214    }
215}
216
217impl TensorElement for I2 {
218    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
219        match value {
220            TensorValue::I2(tensor) => Some(tensor.clone()),
221            _ => None,
222        }
223    }
224
225    fn into_value(tensor: Tensor<Self>) -> TensorValue {
226        TensorValue::I2(tensor)
227    }
228}
229
230impl TensorElement for I1 {
231    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
232        match value {
233            TensorValue::I1(tensor) => Some(tensor.clone()),
234            _ => None,
235        }
236    }
237
238    fn into_value(tensor: Tensor<Self>) -> TensorValue {
239        TensorValue::I1(tensor)
240    }
241}
242
243impl TensorElement for U4 {
244    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
245        match value {
246            TensorValue::U4(tensor) => Some(tensor.clone()),
247            _ => None,
248        }
249    }
250
251    fn into_value(tensor: Tensor<Self>) -> TensorValue {
252        TensorValue::U4(tensor)
253    }
254}
255
256impl TensorElement for U2 {
257    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
258        match value {
259            TensorValue::U2(tensor) => Some(tensor.clone()),
260            _ => None,
261        }
262    }
263
264    fn into_value(tensor: Tensor<Self>) -> TensorValue {
265        TensorValue::U2(tensor)
266    }
267}
268
269impl TensorElement for U1 {
270    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
271        match value {
272            TensorValue::U1(tensor) => Some(tensor.clone()),
273            _ => None,
274        }
275    }
276
277    fn into_value(tensor: Tensor<Self>) -> TensorValue {
278        TensorValue::U1(tensor)
279    }
280}
281
282impl TensorElement for T2 {
283    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
284        match value {
285            TensorValue::T2(tensor) => Some(tensor.clone()),
286            _ => None,
287        }
288    }
289
290    fn into_value(tensor: Tensor<Self>) -> TensorValue {
291        TensorValue::T2(tensor)
292    }
293}
294
295impl TensorElement for T1 {
296    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
297        match value {
298            TensorValue::T1(tensor) => Some(tensor.clone()),
299            _ => None,
300        }
301    }
302
303    fn into_value(tensor: Tensor<Self>) -> TensorValue {
304        TensorValue::T1(tensor)
305    }
306}
307
308impl TensorElement for Bitset {
309    fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
310        match value {
311            TensorValue::Bitset(tensor) => Some(tensor.clone()),
312            _ => None,
313        }
314    }
315
316    fn into_value(tensor: Tensor<Self>) -> TensorValue {
317        TensorValue::Bitset(tensor)
318    }
319}
320
/// Supported element dtypes.
///
/// Each variant has a same-named `TensorValue` variant, and
/// `DType::from_ident` parses the lowercase spelling of the variant name.
///
/// NOTE(review): keep the variant order stable. serde's derived impls
/// identify variants by index as well as by name, so index-based formats
/// (e.g. bincode) would silently change meaning if variants were reordered.
/// Confirm which formats are in use before reordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DType {
    /// Signed 8-bit integer (`i8`).
    I8,
    /// Signed 16-bit integer (`i16`).
    I16,
    /// 32-bit float (`f32`).
    F32,
    /// 64-bit float (`f64`).
    F64,
    /// Unsigned 8-bit integer (`u8`).
    U8,
    /// Unsigned 16-bit integer (`u16`).
    U16,
    /// Signed 32-bit integer (`i32`).
    I32,
    /// Signed 64-bit integer (`i64`).
    I64,
    /// Unsigned 32-bit integer (`u32`).
    U32,
    /// Unsigned 64-bit integer (`u64`).
    U64,
    /// Boolean; `bit_width` reports 8 bits.
    Bool,
    /// Bitset element; `bit_width` reports 8 bits and it is not a packed dtype.
    Bitset,
    /// 16-bit float (`f16`).
    F16,
    /// 16-bit bfloat (`bf16`).
    BF16,
    /// 8-bit float; also parsed from the aliases `f8e5m2` and `float8e5m2`.
    F8,
    /// Packed 4-bit signed integer.
    I4,
    /// Packed 2-bit signed integer.
    I2,
    /// Packed 1-bit signed integer.
    I1,
    /// Packed 4-bit unsigned integer.
    U4,
    /// Packed 2-bit unsigned integer.
    U2,
    /// Packed 1-bit unsigned integer.
    U1,
    /// Packed 2-bit `t2` element (value semantics are defined elsewhere — TODO document).
    T2,
    /// Packed 1-bit `t1` element (value semantics are defined elsewhere — TODO document).
    T1,
}
348
349impl DType {
350    /// Parse a dtype from its identifier string.
351    pub fn from_ident(ident: &str) -> Result<Self> {
352        match ident {
353            "i8" => Ok(DType::I8),
354            "i16" => Ok(DType::I16),
355            "f32" => Ok(DType::F32),
356            "f64" => Ok(DType::F64),
357            "u8" => Ok(DType::U8),
358            "u16" => Ok(DType::U16),
359            "i32" => Ok(DType::I32),
360            "i64" => Ok(DType::I64),
361            "u32" => Ok(DType::U32),
362            "u64" => Ok(DType::U64),
363            "bool" => Ok(DType::Bool),
364            "bitset" => Ok(DType::Bitset),
365            "f16" => Ok(DType::F16),
366            "bf16" => Ok(DType::BF16),
367            "f8" | "f8e5m2" | "float8e5m2" => Ok(DType::F8),
368            "i4" => Ok(DType::I4),
369            "i2" => Ok(DType::I2),
370            "i1" => Ok(DType::I1),
371            "u4" => Ok(DType::U4),
372            "u2" => Ok(DType::U2),
373            "u1" => Ok(DType::U1),
374            "t2" => Ok(DType::T2),
375            "t1" => Ok(DType::T1),
376            _ => Err(anyhow!("unsupported dtype: {}", ident)),
377        }
378    }
379
380    /// True if the dtype is supported across all backends.
381    pub fn is_universal(self) -> bool {
382        matches!(
383            self,
384            DType::F64
385                | DType::F32
386                | DType::I64
387                | DType::I32
388                | DType::I16
389                | DType::I8
390                | DType::U64
391                | DType::U32
392                | DType::U16
393                | DType::U8
394                | DType::Bool
395        )
396    }
397
398    /// True if the dtype is packed (bit-level).
399    pub fn is_packed(self) -> bool {
400        matches!(
401            self,
402            DType::I1
403                | DType::I2
404                | DType::I4
405                | DType::U1
406                | DType::U2
407                | DType::U4
408                | DType::T1
409                | DType::T2
410        )
411    }
412
413    /// True if the dtype is a floating-point type.
414    pub fn is_float(self) -> bool {
415        matches!(self, DType::F8 | DType::F16 | DType::BF16 | DType::F32 | DType::F64)
416    }
417
418    /// True if the dtype is a signed integer type.
419    pub fn is_signed_int(self) -> bool {
420        matches!(self, DType::I8 | DType::I16 | DType::I32 | DType::I64)
421    }
422
423    /// True if the dtype is a packed signed integer type.
424    pub fn is_packed_signed(self) -> bool {
425        matches!(self, DType::I1 | DType::I2 | DType::I4)
426    }
427
428    /// Bit width of a single logical element.
429    pub fn bit_width(self) -> u8 {
430        match self {
431            DType::I1 => 1,
432            DType::I2 => 2,
433            DType::I4 => 4,
434            DType::U1 => 1,
435            DType::U2 => 2,
436            DType::U4 => 4,
437            DType::T1 => 1,
438            DType::T2 => 2,
439            DType::I8 | DType::U8 | DType::Bool => 8,
440            DType::I16 | DType::U16 | DType::F16 | DType::BF16 => 16,
441            DType::I32 | DType::U32 | DType::F32 => 32,
442            DType::I64 | DType::U64 | DType::F64 => 64,
443            DType::F8 => 8,
444            DType::Bitset => 8,
445        }
446    }
447
448    /// Storage length in elements for a logical length.
449    pub fn storage_len(self, logical_len: usize) -> usize {
450        if self.is_packed() {
451            let bits = logical_len.saturating_mul(self.bit_width() as usize);
452            (bits + 7) / 8
453        } else {
454            logical_len
455        }
456    }
457}
458
/// Runtime tensor value with an enum over concrete dtypes.
///
/// Every variant wraps a typed `Tensor`, and [`TensorValue::dtype`] reports
/// the same-named [`DType`] variant. The packed variants (`I4`, `I2`, `I1`,
/// `U4`, `U2`, `U1`, `T2`, `T1`) hold `DType::storage_len(numel(shape))`
/// storage elements rather than one per logical element, at least when
/// built by [`TensorValue::zeros`].
#[derive(Debug, Clone)]
pub enum TensorValue {
    I8(Tensor<i8>),
    I16(Tensor<i16>),
    F32(Tensor<f32>),
    F64(Tensor<f64>),
    U8(Tensor<u8>),
    U16(Tensor<u16>),
    I32(Tensor<i32>),
    I64(Tensor<i64>),
    U32(Tensor<u32>),
    U64(Tensor<u64>),
    Bool(Tensor<bool>),
    Bitset(Tensor<Bitset>),
    F16(Tensor<F16>),
    BF16(Tensor<BF16>),
    F8(Tensor<F8>),
    I4(Tensor<I4>),
    I2(Tensor<I2>),
    I1(Tensor<I1>),
    U4(Tensor<U4>),
    U2(Tensor<U2>),
    U1(Tensor<U1>),
    T2(Tensor<T2>),
    T1(Tensor<T1>),
}
486
// SAFETY: NOTE(review) — this asserts that a `TensorValue`, and therefore every
// `Tensor<T>` payload it can hold, may be moved to another thread.
//
// The stated rationale is that values are moved across threads but never
// shared concurrently. That is exactly the `Send` contract; concurrent shared
// access is governed by `Sync`, which this impl does not touch.
//
// Soundness rests on `Tensor<T>`'s fields, which are not visible in this file.
// The impl would be unsound if a tensor held, for example, an `Rc` or a raw
// pointer aliased by another live handle.
//
// If `Tensor<T>` is already `Send` for every payload type, the compiler makes
// `TensorValue` `Send` automatically and this impl is redundant — TODO confirm
// and remove it if so.
unsafe impl Send for TensorValue {}
489
490impl TensorValue {
491    /// Return the dtype of this value.
492    pub fn dtype(&self) -> DType {
493        match self {
494            TensorValue::I8(_) => DType::I8,
495            TensorValue::I16(_) => DType::I16,
496            TensorValue::F32(_) => DType::F32,
497            TensorValue::F64(_) => DType::F64,
498            TensorValue::U8(_) => DType::U8,
499            TensorValue::U16(_) => DType::U16,
500            TensorValue::I32(_) => DType::I32,
501            TensorValue::I64(_) => DType::I64,
502            TensorValue::U32(_) => DType::U32,
503            TensorValue::U64(_) => DType::U64,
504            TensorValue::Bool(_) => DType::Bool,
505            TensorValue::Bitset(_) => DType::Bitset,
506            TensorValue::F16(_) => DType::F16,
507            TensorValue::BF16(_) => DType::BF16,
508            TensorValue::F8(_) => DType::F8,
509            TensorValue::I4(_) => DType::I4,
510            TensorValue::I2(_) => DType::I2,
511            TensorValue::I1(_) => DType::I1,
512            TensorValue::U4(_) => DType::U4,
513            TensorValue::U2(_) => DType::U2,
514            TensorValue::U1(_) => DType::U1,
515            TensorValue::T2(_) => DType::T2,
516            TensorValue::T1(_) => DType::T1,
517        }
518    }
519
520    /// Return the logical element count.
521    pub fn len(&self) -> usize {
522        numel(self.shape())
523    }
524
525    /// Return the tensor shape.
526    pub fn shape(&self) -> &[usize] {
527        match self {
528            TensorValue::I8(tensor) => tensor.shape(),
529            TensorValue::I16(tensor) => tensor.shape(),
530            TensorValue::F32(tensor) => tensor.shape(),
531            TensorValue::F64(tensor) => tensor.shape(),
532            TensorValue::U8(tensor) => tensor.shape(),
533            TensorValue::U16(tensor) => tensor.shape(),
534            TensorValue::I32(tensor) => tensor.shape(),
535            TensorValue::I64(tensor) => tensor.shape(),
536            TensorValue::U32(tensor) => tensor.shape(),
537            TensorValue::U64(tensor) => tensor.shape(),
538            TensorValue::Bool(tensor) => tensor.shape(),
539            TensorValue::Bitset(tensor) => tensor.shape(),
540            TensorValue::F16(tensor) => tensor.shape(),
541            TensorValue::BF16(tensor) => tensor.shape(),
542            TensorValue::F8(tensor) => tensor.shape(),
543            TensorValue::I4(tensor) => tensor.shape(),
544            TensorValue::I2(tensor) => tensor.shape(),
545            TensorValue::I1(tensor) => tensor.shape(),
546            TensorValue::U4(tensor) => tensor.shape(),
547            TensorValue::U2(tensor) => tensor.shape(),
548            TensorValue::U1(tensor) => tensor.shape(),
549            TensorValue::T2(tensor) => tensor.shape(),
550            TensorValue::T1(tensor) => tensor.shape(),
551        }
552    }
553
554    /// Return the tensor strides.
555    pub fn strides(&self) -> &[usize] {
556        match self {
557            TensorValue::I8(tensor) => tensor.strides(),
558            TensorValue::I16(tensor) => tensor.strides(),
559            TensorValue::F32(tensor) => tensor.strides(),
560            TensorValue::F64(tensor) => tensor.strides(),
561            TensorValue::U8(tensor) => tensor.strides(),
562            TensorValue::U16(tensor) => tensor.strides(),
563            TensorValue::I32(tensor) => tensor.strides(),
564            TensorValue::I64(tensor) => tensor.strides(),
565            TensorValue::U32(tensor) => tensor.strides(),
566            TensorValue::U64(tensor) => tensor.strides(),
567            TensorValue::Bool(tensor) => tensor.strides(),
568            TensorValue::Bitset(tensor) => tensor.strides(),
569            TensorValue::F16(tensor) => tensor.strides(),
570            TensorValue::BF16(tensor) => tensor.strides(),
571            TensorValue::F8(tensor) => tensor.strides(),
572            TensorValue::I4(tensor) => tensor.strides(),
573            TensorValue::I2(tensor) => tensor.strides(),
574            TensorValue::I1(tensor) => tensor.strides(),
575            TensorValue::U4(tensor) => tensor.strides(),
576            TensorValue::U2(tensor) => tensor.strides(),
577            TensorValue::U1(tensor) => tensor.strides(),
578            TensorValue::T2(tensor) => tensor.strides(),
579            TensorValue::T1(tensor) => tensor.strides(),
580        }
581    }
582
583    /// Construct a zero-filled tensor for a dtype and shape.
584    pub fn zeros(dtype: DType, shape: &[usize]) -> Self {
585        let len = numel(shape);
586        let packed_len = dtype.storage_len(len);
587        match dtype {
588            DType::I8 => TensorValue::I8(
589                Tensor::from_vec_with_opts(vec![0; len], TensorOptions {
590                    shape: Some(shape.to_vec()),
591                    ..TensorOptions::default()
592                })
593                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
594            ),
595            DType::I16 => TensorValue::I16(
596                Tensor::from_vec_with_opts(vec![0; len], TensorOptions {
597                    shape: Some(shape.to_vec()),
598                    ..TensorOptions::default()
599                })
600                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
601            ),
602            DType::F32 => TensorValue::F32(
603                Tensor::from_vec_with_opts(vec![0.0; len], TensorOptions {
604                    shape: Some(shape.to_vec()),
605                    ..TensorOptions::default()
606                })
607                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
608            ),
609            DType::F64 => TensorValue::F64(
610                Tensor::from_vec_with_opts(vec![0.0; len], TensorOptions {
611                    shape: Some(shape.to_vec()),
612                    ..TensorOptions::default()
613                })
614                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
615            ),
616            DType::U8 => TensorValue::U8(
617                Tensor::from_vec_with_opts(vec![0; len], TensorOptions {
618                    shape: Some(shape.to_vec()),
619                    ..TensorOptions::default()
620                })
621                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
622            ),
623            DType::U16 => TensorValue::U16(
624                Tensor::from_vec_with_opts(vec![0; len], TensorOptions {
625                    shape: Some(shape.to_vec()),
626                    ..TensorOptions::default()
627                })
628                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
629            ),
630            DType::I32 => TensorValue::I32(
631                Tensor::from_vec_with_opts(vec![0; len], TensorOptions {
632                    shape: Some(shape.to_vec()),
633                    ..TensorOptions::default()
634                })
635                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
636            ),
637            DType::I64 => TensorValue::I64(
638                Tensor::from_vec_with_opts(vec![0; len], TensorOptions {
639                    shape: Some(shape.to_vec()),
640                    ..TensorOptions::default()
641                })
642                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
643            ),
644            DType::U32 => TensorValue::U32(
645                Tensor::from_vec_with_opts(vec![0; len], TensorOptions {
646                    shape: Some(shape.to_vec()),
647                    ..TensorOptions::default()
648                })
649                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
650            ),
651            DType::U64 => TensorValue::U64(
652                Tensor::from_vec_with_opts(vec![0; len], TensorOptions {
653                    shape: Some(shape.to_vec()),
654                    ..TensorOptions::default()
655                })
656                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
657            ),
658            DType::Bool => TensorValue::Bool(
659                Tensor::from_vec_with_opts(vec![false; len], TensorOptions {
660                    shape: Some(shape.to_vec()),
661                    ..TensorOptions::default()
662                })
663                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
664            ),
665            DType::Bitset => TensorValue::Bitset(
666                Tensor::from_vec_with_opts(vec![Bitset { bits: 0 }; len], TensorOptions {
667                    shape: Some(shape.to_vec()),
668                    ..TensorOptions::default()
669                })
670                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
671            ),
672            DType::F16 => TensorValue::F16(
673                Tensor::from_vec_with_opts(vec![F16 { bits: 0 }; len], TensorOptions {
674                    shape: Some(shape.to_vec()),
675                    ..TensorOptions::default()
676                })
677                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
678            ),
679            DType::BF16 => TensorValue::BF16(
680                Tensor::from_vec_with_opts(vec![BF16 { bits: 0 }; len], TensorOptions {
681                    shape: Some(shape.to_vec()),
682                    ..TensorOptions::default()
683                })
684                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
685            ),
686            DType::F8 => TensorValue::F8(
687                Tensor::from_vec_with_opts(vec![F8 { bits: 0 }; len], TensorOptions {
688                    shape: Some(shape.to_vec()),
689                    ..TensorOptions::default()
690                })
691                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
692            ),
693            DType::I4 => TensorValue::I4(
694                Tensor::from_vec_with_opts(vec![I4 { bits: 0 }; packed_len], TensorOptions {
695                    shape: Some(shape.to_vec()),
696                    allow_len_mismatch: true,
697                    ..TensorOptions::default()
698                })
699                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
700            ),
701            DType::I2 => TensorValue::I2(
702                Tensor::from_vec_with_opts(vec![I2 { bits: 0 }; packed_len], TensorOptions {
703                    shape: Some(shape.to_vec()),
704                    allow_len_mismatch: true,
705                    ..TensorOptions::default()
706                })
707                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
708            ),
709            DType::I1 => TensorValue::I1(
710                Tensor::from_vec_with_opts(vec![I1 { bits: 0 }; packed_len], TensorOptions {
711                    shape: Some(shape.to_vec()),
712                    allow_len_mismatch: true,
713                    ..TensorOptions::default()
714                })
715                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
716            ),
717            DType::U4 => TensorValue::U4(
718                Tensor::from_vec_with_opts(vec![U4 { bits: 0 }; packed_len], TensorOptions {
719                    shape: Some(shape.to_vec()),
720                    allow_len_mismatch: true,
721                    ..TensorOptions::default()
722                })
723                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
724            ),
725            DType::U2 => TensorValue::U2(
726                Tensor::from_vec_with_opts(vec![U2 { bits: 0 }; packed_len], TensorOptions {
727                    shape: Some(shape.to_vec()),
728                    allow_len_mismatch: true,
729                    ..TensorOptions::default()
730                })
731                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
732            ),
733            DType::U1 => TensorValue::U1(
734                Tensor::from_vec_with_opts(vec![U1 { bits: 0 }; packed_len], TensorOptions {
735                    shape: Some(shape.to_vec()),
736                    allow_len_mismatch: true,
737                    ..TensorOptions::default()
738                })
739                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
740            ),
741            DType::T2 => TensorValue::T2(
742                Tensor::from_vec_with_opts(vec![T2 { bits: 0 }; packed_len], TensorOptions {
743                    shape: Some(shape.to_vec()),
744                    allow_len_mismatch: true,
745                    ..TensorOptions::default()
746                })
747                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
748            ),
749            DType::T1 => TensorValue::T1(
750                Tensor::from_vec_with_opts(vec![T1 { bits: 0 }; packed_len], TensorOptions {
751                    shape: Some(shape.to_vec()),
752                    allow_len_mismatch: true,
753                    ..TensorOptions::default()
754                })
755                .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
756            ),
757        }
758    }
759
760    /// Borrow as an i8 tensor.
761    pub fn as_i8(&self) -> Result<&Tensor<i8>> {
762        match self {
763            TensorValue::I8(tensor) => Ok(tensor),
764            _ => Err(anyhow!("expected i8 tensor")),
765        }
766    }
767
768    /// Borrow as an i16 tensor.
769    pub fn as_i16(&self) -> Result<&Tensor<i16>> {
770        match self {
771            TensorValue::I16(tensor) => Ok(tensor),
772            _ => Err(anyhow!("expected i16 tensor")),
773        }
774    }
775
776    /// Borrow as an f32 tensor.
777    pub fn as_f32(&self) -> Result<&Tensor<f32>> {
778        match self {
779            TensorValue::F32(tensor) => Ok(tensor),
780            _ => Err(anyhow!("expected f32 tensor")),
781        }
782    }
783
784    /// Borrow as an f64 tensor.
785    pub fn as_f64(&self) -> Result<&Tensor<f64>> {
786        match self {
787            TensorValue::F64(tensor) => Ok(tensor),
788            _ => Err(anyhow!("expected f64 tensor")),
789        }
790    }
791
792    /// Borrow as a u8 tensor.
793    pub fn as_u8(&self) -> Result<&Tensor<u8>> {
794        match self {
795            TensorValue::U8(tensor) => Ok(tensor),
796            _ => Err(anyhow!("expected u8 tensor")),
797        }
798    }
799
800    /// Borrow as a u16 tensor.
801    pub fn as_u16(&self) -> Result<&Tensor<u16>> {
802        match self {
803            TensorValue::U16(tensor) => Ok(tensor),
804            _ => Err(anyhow!("expected u16 tensor")),
805        }
806    }
807
808    /// Borrow as an i32 tensor.
809    pub fn as_i32(&self) -> Result<&Tensor<i32>> {
810        match self {
811            TensorValue::I32(tensor) => Ok(tensor),
812            _ => Err(anyhow!("expected i32 tensor")),
813        }
814    }
815
816    /// Borrow as an i64 tensor.
817    pub fn as_i64(&self) -> Result<&Tensor<i64>> {
818        match self {
819            TensorValue::I64(tensor) => Ok(tensor),
820            _ => Err(anyhow!("expected i64 tensor")),
821        }
822    }
823
824    /// Borrow as a u32 tensor.
825    pub fn as_u32(&self) -> Result<&Tensor<u32>> {
826        match self {
827            TensorValue::U32(tensor) => Ok(tensor),
828            _ => Err(anyhow!("expected u32 tensor")),
829        }
830    }
831
832    /// Borrow as a u64 tensor.
833    pub fn as_u64(&self) -> Result<&Tensor<u64>> {
834        match self {
835            TensorValue::U64(tensor) => Ok(tensor),
836            _ => Err(anyhow!("expected u64 tensor")),
837        }
838    }
839
840    /// Borrow as a bool tensor.
841    pub fn as_bool(&self) -> Result<&Tensor<bool>> {
842        match self {
843            TensorValue::Bool(tensor) => Ok(tensor),
844            _ => Err(anyhow!("expected bool tensor")),
845        }
846    }
847
848    /// Borrow as a Bitset tensor.
849    pub fn as_bitset(&self) -> Result<&Tensor<Bitset>> {
850        match self {
851            TensorValue::Bitset(tensor) => Ok(tensor),
852            _ => Err(anyhow!("expected bitset tensor")),
853        }
854    }
855
856    /// Borrow as an F16 tensor.
857    pub fn as_f16(&self) -> Result<&Tensor<F16>> {
858        match self {
859            TensorValue::F16(tensor) => Ok(tensor),
860            _ => Err(anyhow!("expected f16 tensor")),
861        }
862    }
863
864    /// Borrow as a BF16 tensor.
865    pub fn as_bf16(&self) -> Result<&Tensor<BF16>> {
866        match self {
867            TensorValue::BF16(tensor) => Ok(tensor),
868            _ => Err(anyhow!("expected bf16 tensor")),
869        }
870    }
871
872    /// Borrow as an F8 tensor.
873    pub fn as_f8(&self) -> Result<&Tensor<F8>> {
874        match self {
875            TensorValue::F8(tensor) => Ok(tensor),
876            _ => Err(anyhow!("expected f8 tensor")),
877        }
878    }
879
880    /// Borrow as an I4 tensor.
881    pub fn as_i4(&self) -> Result<&Tensor<I4>> {
882        match self {
883            TensorValue::I4(tensor) => Ok(tensor),
884            _ => Err(anyhow!("expected i4 tensor")),
885        }
886    }
887
888    /// Borrow as an I2 tensor.
889    pub fn as_i2(&self) -> Result<&Tensor<I2>> {
890        match self {
891            TensorValue::I2(tensor) => Ok(tensor),
892            _ => Err(anyhow!("expected i2 tensor")),
893        }
894    }
895
896    /// Borrow as an I1 tensor.
897    pub fn as_i1(&self) -> Result<&Tensor<I1>> {
898        match self {
899            TensorValue::I1(tensor) => Ok(tensor),
900            _ => Err(anyhow!("expected i1 tensor")),
901        }
902    }
903
904    /// Borrow as a U4 tensor.
905    pub fn as_u4(&self) -> Result<&Tensor<U4>> {
906        match self {
907            TensorValue::U4(tensor) => Ok(tensor),
908            _ => Err(anyhow!("expected u4 tensor")),
909        }
910    }
911
912    /// Borrow as a U2 tensor.
913    pub fn as_u2(&self) -> Result<&Tensor<U2>> {
914        match self {
915            TensorValue::U2(tensor) => Ok(tensor),
916            _ => Err(anyhow!("expected u2 tensor")),
917        }
918    }
919
920    /// Borrow as a U1 tensor.
921    pub fn as_u1(&self) -> Result<&Tensor<U1>> {
922        match self {
923            TensorValue::U1(tensor) => Ok(tensor),
924            _ => Err(anyhow!("expected u1 tensor")),
925        }
926    }
927
928    /// Borrow as a T2 tensor.
929    pub fn as_t2(&self) -> Result<&Tensor<T2>> {
930        match self {
931            TensorValue::T2(tensor) => Ok(tensor),
932            _ => Err(anyhow!("expected t2 tensor")),
933        }
934    }
935
936    /// Borrow as a T1 tensor.
937    pub fn as_t1(&self) -> Result<&Tensor<T1>> {
938        match self {
939            TensorValue::T1(tensor) => Ok(tensor),
940            _ => Err(anyhow!("expected t1 tensor")),
941        }
942    }
943}
944
945impl From<Tensor<i8>> for TensorValue {
946    fn from(value: Tensor<i8>) -> Self {
947        TensorValue::I8(value)
948    }
949}
950
951impl From<Tensor<i16>> for TensorValue {
952    fn from(value: Tensor<i16>) -> Self {
953        TensorValue::I16(value)
954    }
955}
956
957impl From<Tensor<f32>> for TensorValue {
958    fn from(value: Tensor<f32>) -> Self {
959        TensorValue::F32(value)
960    }
961}
962
963impl From<Tensor<f64>> for TensorValue {
964    fn from(value: Tensor<f64>) -> Self {
965        TensorValue::F64(value)
966    }
967}
968
969impl From<Tensor<BF16>> for TensorValue {
970    fn from(value: Tensor<BF16>) -> Self {
971        TensorValue::BF16(value)
972    }
973}
974
975impl From<Tensor<F8>> for TensorValue {
976    fn from(value: Tensor<F8>) -> Self {
977        TensorValue::F8(value)
978    }
979}
980
981impl From<Tensor<I4>> for TensorValue {
982    fn from(value: Tensor<I4>) -> Self {
983        TensorValue::I4(value)
984    }
985}
986
987impl From<Tensor<I2>> for TensorValue {
988    fn from(value: Tensor<I2>) -> Self {
989        TensorValue::I2(value)
990    }
991}
992
993impl From<Tensor<I1>> for TensorValue {
994    fn from(value: Tensor<I1>) -> Self {
995        TensorValue::I1(value)
996    }
997}
998
999impl From<Tensor<U4>> for TensorValue {
1000    fn from(value: Tensor<U4>) -> Self {
1001        TensorValue::U4(value)
1002    }
1003}
1004
1005impl From<Tensor<U2>> for TensorValue {
1006    fn from(value: Tensor<U2>) -> Self {
1007        TensorValue::U2(value)
1008    }
1009}
1010
1011impl From<Tensor<U1>> for TensorValue {
1012    fn from(value: Tensor<U1>) -> Self {
1013        TensorValue::U1(value)
1014    }
1015}
1016
1017impl From<Tensor<T2>> for TensorValue {
1018    fn from(value: Tensor<T2>) -> Self {
1019        TensorValue::T2(value)
1020    }
1021}
1022
1023impl From<Tensor<T1>> for TensorValue {
1024    fn from(value: Tensor<T1>) -> Self {
1025        TensorValue::T1(value)
1026    }
1027}
1028
1029impl From<Tensor<i32>> for TensorValue {
1030    fn from(value: Tensor<i32>) -> Self {
1031        TensorValue::I32(value)
1032    }
1033}
1034
1035impl From<Tensor<i64>> for TensorValue {
1036    fn from(value: Tensor<i64>) -> Self {
1037        TensorValue::I64(value)
1038    }
1039}
1040
1041impl From<Tensor<u8>> for TensorValue {
1042    fn from(value: Tensor<u8>) -> Self {
1043        TensorValue::U8(value)
1044    }
1045}
1046
1047impl From<Tensor<u16>> for TensorValue {
1048    fn from(value: Tensor<u16>) -> Self {
1049        TensorValue::U16(value)
1050    }
1051}
1052
1053impl From<Tensor<u32>> for TensorValue {
1054    fn from(value: Tensor<u32>) -> Self {
1055        TensorValue::U32(value)
1056    }
1057}
1058
1059impl From<Tensor<u64>> for TensorValue {
1060    fn from(value: Tensor<u64>) -> Self {
1061        TensorValue::U64(value)
1062    }
1063}
1064
1065impl From<Tensor<bool>> for TensorValue {
1066    fn from(value: Tensor<bool>) -> Self {
1067        TensorValue::Bool(value)
1068    }
1069}
1070
1071impl From<Tensor<Bitset>> for TensorValue {
1072    fn from(value: Tensor<Bitset>) -> Self {
1073        TensorValue::Bitset(value)
1074    }
1075}
1076
1077impl From<Tensor<F16>> for TensorValue {
1078    fn from(value: Tensor<F16>) -> Self {
1079        TensorValue::F16(value)
1080    }
1081}
1082
1083impl From<i8> for TensorValue {
1084    fn from(value: i8) -> Self {
1085        TensorValue::I8(Tensor::from_scalar(value))
1086    }
1087}
1088
1089impl From<i16> for TensorValue {
1090    fn from(value: i16) -> Self {
1091        TensorValue::I16(Tensor::from_scalar(value))
1092    }
1093}
1094
1095impl From<i32> for TensorValue {
1096    fn from(value: i32) -> Self {
1097        TensorValue::I32(Tensor::from_scalar(value))
1098    }
1099}
1100
1101impl From<i64> for TensorValue {
1102    fn from(value: i64) -> Self {
1103        TensorValue::I64(Tensor::from_scalar(value))
1104    }
1105}
1106
1107impl From<u8> for TensorValue {
1108    fn from(value: u8) -> Self {
1109        TensorValue::U8(Tensor::from_scalar(value))
1110    }
1111}
1112
1113impl From<u16> for TensorValue {
1114    fn from(value: u16) -> Self {
1115        TensorValue::U16(Tensor::from_scalar(value))
1116    }
1117}
1118
1119impl From<u32> for TensorValue {
1120    fn from(value: u32) -> Self {
1121        TensorValue::U32(Tensor::from_scalar(value))
1122    }
1123}
1124
1125impl From<u64> for TensorValue {
1126    fn from(value: u64) -> Self {
1127        TensorValue::U64(Tensor::from_scalar(value))
1128    }
1129}
1130
1131impl From<f32> for TensorValue {
1132    fn from(value: f32) -> Self {
1133        TensorValue::F32(Tensor::from_scalar(value))
1134    }
1135}
1136
1137impl From<f64> for TensorValue {
1138    fn from(value: f64) -> Self {
1139        TensorValue::F64(Tensor::from_scalar(value))
1140    }
1141}
1142
1143impl From<bool> for TensorValue {
1144    fn from(value: bool) -> Self {
1145        TensorValue::Bool(Tensor::from_scalar(value))
1146    }
1147}
1148
1149impl From<Bitset> for TensorValue {
1150    fn from(value: Bitset) -> Self {
1151        TensorValue::Bitset(Tensor::from_scalar(value))
1152    }
1153}
1154
1155impl From<F16> for TensorValue {
1156    fn from(value: F16) -> Self {
1157        TensorValue::F16(Tensor::from_scalar(value))
1158    }
1159}
1160
1161impl From<BF16> for TensorValue {
1162    fn from(value: BF16) -> Self {
1163        TensorValue::BF16(Tensor::from_scalar(value))
1164    }
1165}
1166
1167impl From<F8> for TensorValue {
1168    fn from(value: F8) -> Self {
1169        TensorValue::F8(Tensor::from_scalar(value))
1170    }
1171}
1172
1173impl From<I4> for TensorValue {
1174    fn from(value: I4) -> Self {
1175        TensorValue::I4(Tensor::from_scalar(value))
1176    }
1177}
1178
1179impl From<I2> for TensorValue {
1180    fn from(value: I2) -> Self {
1181        TensorValue::I2(Tensor::from_scalar(value))
1182    }
1183}
1184
1185impl From<I1> for TensorValue {
1186    fn from(value: I1) -> Self {
1187        TensorValue::I1(Tensor::from_scalar(value))
1188    }
1189}
1190
1191impl From<U4> for TensorValue {
1192    fn from(value: U4) -> Self {
1193        TensorValue::U4(Tensor::from_scalar(value))
1194    }
1195}
1196
1197impl From<U2> for TensorValue {
1198    fn from(value: U2) -> Self {
1199        TensorValue::U2(Tensor::from_scalar(value))
1200    }
1201}
1202
1203impl From<U1> for TensorValue {
1204    fn from(value: U1) -> Self {
1205        TensorValue::U1(Tensor::from_scalar(value))
1206    }
1207}
1208
1209impl From<T2> for TensorValue {
1210    fn from(value: T2) -> Self {
1211        TensorValue::T2(Tensor::from_scalar(value))
1212    }
1213}
1214
1215impl From<T1> for TensorValue {
1216    fn from(value: T1) -> Self {
1217        TensorValue::T1(Tensor::from_scalar(value))
1218    }
1219}