1use anyhow::{anyhow, Result};
2use serde::{Deserialize, Serialize};
3
4use super::{
5 numel, Bitset, BF16, F16, F8, I1, I2, I4, T1, T2, U1, U2, U4, Tensor, TensorOptions,
6};
7
/// Ties a Rust element type to the `TensorValue` variant that carries tensors
/// of that type, so generic code can convert between the two.
pub trait TensorElement: Sized + Clone {
    /// Extracts a tensor of `Self` from a dynamically typed value.
    ///
    /// The implementations in this file return a clone of the wrapped tensor
    /// when `value` holds the variant for `Self`, and `None` otherwise.
    fn from_value(value: &TensorValue) -> Option<Tensor<Self>>;
    /// Wraps `tensor` in the `TensorValue` variant that corresponds to `Self`.
    fn into_value(tensor: Tensor<Self>) -> TensorValue;
}
15
16impl<T> From<Vec<T>> for Tensor<T> {
17 fn from(value: Vec<T>) -> Self {
18 Tensor::new(value)
19 }
20}
21
22impl TensorElement for f32 {
23 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
24 match value {
25 TensorValue::F32(tensor) => Some(tensor.clone()),
26 _ => None,
27 }
28 }
29
30 fn into_value(tensor: Tensor<Self>) -> TensorValue {
31 TensorValue::F32(tensor)
32 }
33}
34
35impl TensorElement for f64 {
36 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
37 match value {
38 TensorValue::F64(tensor) => Some(tensor.clone()),
39 _ => None,
40 }
41 }
42
43 fn into_value(tensor: Tensor<Self>) -> TensorValue {
44 TensorValue::F64(tensor)
45 }
46}
47
48impl TensorElement for i8 {
49 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
50 match value {
51 TensorValue::I8(tensor) => Some(tensor.clone()),
52 _ => None,
53 }
54 }
55
56 fn into_value(tensor: Tensor<Self>) -> TensorValue {
57 TensorValue::I8(tensor)
58 }
59}
60
61impl TensorElement for i16 {
62 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
63 match value {
64 TensorValue::I16(tensor) => Some(tensor.clone()),
65 _ => None,
66 }
67 }
68
69 fn into_value(tensor: Tensor<Self>) -> TensorValue {
70 TensorValue::I16(tensor)
71 }
72}
73
74impl TensorElement for i32 {
75 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
76 match value {
77 TensorValue::I32(tensor) => Some(tensor.clone()),
78 _ => None,
79 }
80 }
81
82 fn into_value(tensor: Tensor<Self>) -> TensorValue {
83 TensorValue::I32(tensor)
84 }
85}
86
87impl TensorElement for i64 {
88 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
89 match value {
90 TensorValue::I64(tensor) => Some(tensor.clone()),
91 _ => None,
92 }
93 }
94
95 fn into_value(tensor: Tensor<Self>) -> TensorValue {
96 TensorValue::I64(tensor)
97 }
98}
99
100impl TensorElement for u8 {
101 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
102 match value {
103 TensorValue::U8(tensor) => Some(tensor.clone()),
104 _ => None,
105 }
106 }
107
108 fn into_value(tensor: Tensor<Self>) -> TensorValue {
109 TensorValue::U8(tensor)
110 }
111}
112
113impl TensorElement for u16 {
114 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
115 match value {
116 TensorValue::U16(tensor) => Some(tensor.clone()),
117 _ => None,
118 }
119 }
120
121 fn into_value(tensor: Tensor<Self>) -> TensorValue {
122 TensorValue::U16(tensor)
123 }
124}
125
126impl TensorElement for u32 {
127 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
128 match value {
129 TensorValue::U32(tensor) => Some(tensor.clone()),
130 _ => None,
131 }
132 }
133
134 fn into_value(tensor: Tensor<Self>) -> TensorValue {
135 TensorValue::U32(tensor)
136 }
137}
138
139impl TensorElement for u64 {
140 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
141 match value {
142 TensorValue::U64(tensor) => Some(tensor.clone()),
143 _ => None,
144 }
145 }
146
147 fn into_value(tensor: Tensor<Self>) -> TensorValue {
148 TensorValue::U64(tensor)
149 }
150}
151
152impl TensorElement for bool {
153 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
154 match value {
155 TensorValue::Bool(tensor) => Some(tensor.clone()),
156 _ => None,
157 }
158 }
159
160 fn into_value(tensor: Tensor<Self>) -> TensorValue {
161 TensorValue::Bool(tensor)
162 }
163}
164
165impl TensorElement for F16 {
166 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
167 match value {
168 TensorValue::F16(tensor) => Some(tensor.clone()),
169 _ => None,
170 }
171 }
172
173 fn into_value(tensor: Tensor<Self>) -> TensorValue {
174 TensorValue::F16(tensor)
175 }
176}
177
178impl TensorElement for BF16 {
179 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
180 match value {
181 TensorValue::BF16(tensor) => Some(tensor.clone()),
182 _ => None,
183 }
184 }
185
186 fn into_value(tensor: Tensor<Self>) -> TensorValue {
187 TensorValue::BF16(tensor)
188 }
189}
190
191impl TensorElement for F8 {
192 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
193 match value {
194 TensorValue::F8(tensor) => Some(tensor.clone()),
195 _ => None,
196 }
197 }
198
199 fn into_value(tensor: Tensor<Self>) -> TensorValue {
200 TensorValue::F8(tensor)
201 }
202}
203
204impl TensorElement for I4 {
205 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
206 match value {
207 TensorValue::I4(tensor) => Some(tensor.clone()),
208 _ => None,
209 }
210 }
211
212 fn into_value(tensor: Tensor<Self>) -> TensorValue {
213 TensorValue::I4(tensor)
214 }
215}
216
217impl TensorElement for I2 {
218 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
219 match value {
220 TensorValue::I2(tensor) => Some(tensor.clone()),
221 _ => None,
222 }
223 }
224
225 fn into_value(tensor: Tensor<Self>) -> TensorValue {
226 TensorValue::I2(tensor)
227 }
228}
229
230impl TensorElement for I1 {
231 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
232 match value {
233 TensorValue::I1(tensor) => Some(tensor.clone()),
234 _ => None,
235 }
236 }
237
238 fn into_value(tensor: Tensor<Self>) -> TensorValue {
239 TensorValue::I1(tensor)
240 }
241}
242
243impl TensorElement for U4 {
244 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
245 match value {
246 TensorValue::U4(tensor) => Some(tensor.clone()),
247 _ => None,
248 }
249 }
250
251 fn into_value(tensor: Tensor<Self>) -> TensorValue {
252 TensorValue::U4(tensor)
253 }
254}
255
256impl TensorElement for U2 {
257 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
258 match value {
259 TensorValue::U2(tensor) => Some(tensor.clone()),
260 _ => None,
261 }
262 }
263
264 fn into_value(tensor: Tensor<Self>) -> TensorValue {
265 TensorValue::U2(tensor)
266 }
267}
268
269impl TensorElement for U1 {
270 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
271 match value {
272 TensorValue::U1(tensor) => Some(tensor.clone()),
273 _ => None,
274 }
275 }
276
277 fn into_value(tensor: Tensor<Self>) -> TensorValue {
278 TensorValue::U1(tensor)
279 }
280}
281
282impl TensorElement for T2 {
283 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
284 match value {
285 TensorValue::T2(tensor) => Some(tensor.clone()),
286 _ => None,
287 }
288 }
289
290 fn into_value(tensor: Tensor<Self>) -> TensorValue {
291 TensorValue::T2(tensor)
292 }
293}
294
295impl TensorElement for T1 {
296 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
297 match value {
298 TensorValue::T1(tensor) => Some(tensor.clone()),
299 _ => None,
300 }
301 }
302
303 fn into_value(tensor: Tensor<Self>) -> TensorValue {
304 TensorValue::T1(tensor)
305 }
306}
307
308impl TensorElement for Bitset {
309 fn from_value(value: &TensorValue) -> Option<Tensor<Self>> {
310 match value {
311 TensorValue::Bitset(tensor) => Some(tensor.clone()),
312 _ => None,
313 }
314 }
315
316 fn into_value(tensor: Tensor<Self>) -> TensorValue {
317 TensorValue::Bitset(tensor)
318 }
319}
320
/// Tag naming the element type of a tensor.
///
/// Each variant has a same-named `TensorValue` variant (see
/// `TensorValue::dtype`). The declaration order is left untouched because
/// derived (de)serialization may depend on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DType {
    I8,
    I16,
    F32,
    F64,
    U8,
    U16,
    I32,
    I64,
    U32,
    U64,
    Bool,
    Bitset,
    F16,
    BF16,
    // Also parsed from the identifiers "f8e5m2" / "float8e5m2" by `from_ident`.
    F8,
    // The variants below are the sub-byte ("packed") types: `bit_width`
    // reports 1, 2 or 4 bits for them and `is_packed` returns `true`.
    I4,
    I2,
    I1,
    U4,
    U2,
    U1,
    T2,
    T1,
}
348
349impl DType {
350 pub fn from_ident(ident: &str) -> Result<Self> {
352 match ident {
353 "i8" => Ok(DType::I8),
354 "i16" => Ok(DType::I16),
355 "f32" => Ok(DType::F32),
356 "f64" => Ok(DType::F64),
357 "u8" => Ok(DType::U8),
358 "u16" => Ok(DType::U16),
359 "i32" => Ok(DType::I32),
360 "i64" => Ok(DType::I64),
361 "u32" => Ok(DType::U32),
362 "u64" => Ok(DType::U64),
363 "bool" => Ok(DType::Bool),
364 "bitset" => Ok(DType::Bitset),
365 "f16" => Ok(DType::F16),
366 "bf16" => Ok(DType::BF16),
367 "f8" | "f8e5m2" | "float8e5m2" => Ok(DType::F8),
368 "i4" => Ok(DType::I4),
369 "i2" => Ok(DType::I2),
370 "i1" => Ok(DType::I1),
371 "u4" => Ok(DType::U4),
372 "u2" => Ok(DType::U2),
373 "u1" => Ok(DType::U1),
374 "t2" => Ok(DType::T2),
375 "t1" => Ok(DType::T1),
376 _ => Err(anyhow!("unsupported dtype: {}", ident)),
377 }
378 }
379
380 pub fn is_universal(self) -> bool {
382 matches!(
383 self,
384 DType::F64
385 | DType::F32
386 | DType::I64
387 | DType::I32
388 | DType::I16
389 | DType::I8
390 | DType::U64
391 | DType::U32
392 | DType::U16
393 | DType::U8
394 | DType::Bool
395 )
396 }
397
398 pub fn is_packed(self) -> bool {
400 matches!(
401 self,
402 DType::I1
403 | DType::I2
404 | DType::I4
405 | DType::U1
406 | DType::U2
407 | DType::U4
408 | DType::T1
409 | DType::T2
410 )
411 }
412
413 pub fn is_float(self) -> bool {
415 matches!(self, DType::F8 | DType::F16 | DType::BF16 | DType::F32 | DType::F64)
416 }
417
418 pub fn is_signed_int(self) -> bool {
420 matches!(self, DType::I8 | DType::I16 | DType::I32 | DType::I64)
421 }
422
423 pub fn is_packed_signed(self) -> bool {
425 matches!(self, DType::I1 | DType::I2 | DType::I4)
426 }
427
428 pub fn bit_width(self) -> u8 {
430 match self {
431 DType::I1 => 1,
432 DType::I2 => 2,
433 DType::I4 => 4,
434 DType::U1 => 1,
435 DType::U2 => 2,
436 DType::U4 => 4,
437 DType::T1 => 1,
438 DType::T2 => 2,
439 DType::I8 | DType::U8 | DType::Bool => 8,
440 DType::I16 | DType::U16 | DType::F16 | DType::BF16 => 16,
441 DType::I32 | DType::U32 | DType::F32 => 32,
442 DType::I64 | DType::U64 | DType::F64 => 64,
443 DType::F8 => 8,
444 DType::Bitset => 8,
445 }
446 }
447
448 pub fn storage_len(self, logical_len: usize) -> usize {
450 if self.is_packed() {
451 let bits = logical_len.saturating_mul(self.bit_width() as usize);
452 (bits + 7) / 8
453 } else {
454 logical_len
455 }
456 }
457}
458
/// A tensor whose element type is chosen at runtime.
///
/// Each variant wraps a `Tensor` of one concrete element type and maps to the
/// same-named `DType` (see `TensorValue::dtype`).
#[derive(Debug, Clone)]
pub enum TensorValue {
    I8(Tensor<i8>),
    I16(Tensor<i16>),
    F32(Tensor<f32>),
    F64(Tensor<f64>),
    U8(Tensor<u8>),
    U16(Tensor<u16>),
    I32(Tensor<i32>),
    I64(Tensor<i64>),
    U32(Tensor<u32>),
    U64(Tensor<u64>),
    Bool(Tensor<bool>),
    Bitset(Tensor<Bitset>),
    F16(Tensor<F16>),
    BF16(Tensor<BF16>),
    F8(Tensor<F8>),
    // Sub-byte element types; `TensorValue::zeros` sizes their buffers with
    // `DType::storage_len` rather than the logical element count.
    I4(Tensor<I4>),
    I2(Tensor<I2>),
    I1(Tensor<I1>),
    U4(Tensor<U4>),
    U2(Tensor<U2>),
    U1(Tensor<U1>),
    T2(Tensor<T2>),
    T1(Tensor<T1>),
}
486
// SAFETY: NOTE(review): no soundness argument is recorded for this impl. It
// asserts that every `Tensor<T>` payload may be moved to another thread; that
// depends on `Tensor`'s internals, which are not visible here. Confirm it, or
// drop this impl if `Tensor<T>` is already `Send` (it would then be derived
// automatically).
unsafe impl Send for TensorValue {}
489
490impl TensorValue {
491 pub fn dtype(&self) -> DType {
493 match self {
494 TensorValue::I8(_) => DType::I8,
495 TensorValue::I16(_) => DType::I16,
496 TensorValue::F32(_) => DType::F32,
497 TensorValue::F64(_) => DType::F64,
498 TensorValue::U8(_) => DType::U8,
499 TensorValue::U16(_) => DType::U16,
500 TensorValue::I32(_) => DType::I32,
501 TensorValue::I64(_) => DType::I64,
502 TensorValue::U32(_) => DType::U32,
503 TensorValue::U64(_) => DType::U64,
504 TensorValue::Bool(_) => DType::Bool,
505 TensorValue::Bitset(_) => DType::Bitset,
506 TensorValue::F16(_) => DType::F16,
507 TensorValue::BF16(_) => DType::BF16,
508 TensorValue::F8(_) => DType::F8,
509 TensorValue::I4(_) => DType::I4,
510 TensorValue::I2(_) => DType::I2,
511 TensorValue::I1(_) => DType::I1,
512 TensorValue::U4(_) => DType::U4,
513 TensorValue::U2(_) => DType::U2,
514 TensorValue::U1(_) => DType::U1,
515 TensorValue::T2(_) => DType::T2,
516 TensorValue::T1(_) => DType::T1,
517 }
518 }
519
520 pub fn len(&self) -> usize {
522 numel(self.shape())
523 }
524
525 pub fn shape(&self) -> &[usize] {
527 match self {
528 TensorValue::I8(tensor) => tensor.shape(),
529 TensorValue::I16(tensor) => tensor.shape(),
530 TensorValue::F32(tensor) => tensor.shape(),
531 TensorValue::F64(tensor) => tensor.shape(),
532 TensorValue::U8(tensor) => tensor.shape(),
533 TensorValue::U16(tensor) => tensor.shape(),
534 TensorValue::I32(tensor) => tensor.shape(),
535 TensorValue::I64(tensor) => tensor.shape(),
536 TensorValue::U32(tensor) => tensor.shape(),
537 TensorValue::U64(tensor) => tensor.shape(),
538 TensorValue::Bool(tensor) => tensor.shape(),
539 TensorValue::Bitset(tensor) => tensor.shape(),
540 TensorValue::F16(tensor) => tensor.shape(),
541 TensorValue::BF16(tensor) => tensor.shape(),
542 TensorValue::F8(tensor) => tensor.shape(),
543 TensorValue::I4(tensor) => tensor.shape(),
544 TensorValue::I2(tensor) => tensor.shape(),
545 TensorValue::I1(tensor) => tensor.shape(),
546 TensorValue::U4(tensor) => tensor.shape(),
547 TensorValue::U2(tensor) => tensor.shape(),
548 TensorValue::U1(tensor) => tensor.shape(),
549 TensorValue::T2(tensor) => tensor.shape(),
550 TensorValue::T1(tensor) => tensor.shape(),
551 }
552 }
553
554 pub fn strides(&self) -> &[usize] {
556 match self {
557 TensorValue::I8(tensor) => tensor.strides(),
558 TensorValue::I16(tensor) => tensor.strides(),
559 TensorValue::F32(tensor) => tensor.strides(),
560 TensorValue::F64(tensor) => tensor.strides(),
561 TensorValue::U8(tensor) => tensor.strides(),
562 TensorValue::U16(tensor) => tensor.strides(),
563 TensorValue::I32(tensor) => tensor.strides(),
564 TensorValue::I64(tensor) => tensor.strides(),
565 TensorValue::U32(tensor) => tensor.strides(),
566 TensorValue::U64(tensor) => tensor.strides(),
567 TensorValue::Bool(tensor) => tensor.strides(),
568 TensorValue::Bitset(tensor) => tensor.strides(),
569 TensorValue::F16(tensor) => tensor.strides(),
570 TensorValue::BF16(tensor) => tensor.strides(),
571 TensorValue::F8(tensor) => tensor.strides(),
572 TensorValue::I4(tensor) => tensor.strides(),
573 TensorValue::I2(tensor) => tensor.strides(),
574 TensorValue::I1(tensor) => tensor.strides(),
575 TensorValue::U4(tensor) => tensor.strides(),
576 TensorValue::U2(tensor) => tensor.strides(),
577 TensorValue::U1(tensor) => tensor.strides(),
578 TensorValue::T2(tensor) => tensor.strides(),
579 TensorValue::T1(tensor) => tensor.strides(),
580 }
581 }
582
583 pub fn zeros(dtype: DType, shape: &[usize]) -> Self {
585 let len = numel(shape);
586 let packed_len = dtype.storage_len(len);
587 match dtype {
588 DType::I8 => TensorValue::I8(
589 Tensor::from_vec_with_opts(vec![0; len], TensorOptions {
590 shape: Some(shape.to_vec()),
591 ..TensorOptions::default()
592 })
593 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
594 ),
595 DType::I16 => TensorValue::I16(
596 Tensor::from_vec_with_opts(vec![0; len], TensorOptions {
597 shape: Some(shape.to_vec()),
598 ..TensorOptions::default()
599 })
600 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
601 ),
602 DType::F32 => TensorValue::F32(
603 Tensor::from_vec_with_opts(vec![0.0; len], TensorOptions {
604 shape: Some(shape.to_vec()),
605 ..TensorOptions::default()
606 })
607 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
608 ),
609 DType::F64 => TensorValue::F64(
610 Tensor::from_vec_with_opts(vec![0.0; len], TensorOptions {
611 shape: Some(shape.to_vec()),
612 ..TensorOptions::default()
613 })
614 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
615 ),
616 DType::U8 => TensorValue::U8(
617 Tensor::from_vec_with_opts(vec![0; len], TensorOptions {
618 shape: Some(shape.to_vec()),
619 ..TensorOptions::default()
620 })
621 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
622 ),
623 DType::U16 => TensorValue::U16(
624 Tensor::from_vec_with_opts(vec![0; len], TensorOptions {
625 shape: Some(shape.to_vec()),
626 ..TensorOptions::default()
627 })
628 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
629 ),
630 DType::I32 => TensorValue::I32(
631 Tensor::from_vec_with_opts(vec![0; len], TensorOptions {
632 shape: Some(shape.to_vec()),
633 ..TensorOptions::default()
634 })
635 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
636 ),
637 DType::I64 => TensorValue::I64(
638 Tensor::from_vec_with_opts(vec![0; len], TensorOptions {
639 shape: Some(shape.to_vec()),
640 ..TensorOptions::default()
641 })
642 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
643 ),
644 DType::U32 => TensorValue::U32(
645 Tensor::from_vec_with_opts(vec![0; len], TensorOptions {
646 shape: Some(shape.to_vec()),
647 ..TensorOptions::default()
648 })
649 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
650 ),
651 DType::U64 => TensorValue::U64(
652 Tensor::from_vec_with_opts(vec![0; len], TensorOptions {
653 shape: Some(shape.to_vec()),
654 ..TensorOptions::default()
655 })
656 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
657 ),
658 DType::Bool => TensorValue::Bool(
659 Tensor::from_vec_with_opts(vec![false; len], TensorOptions {
660 shape: Some(shape.to_vec()),
661 ..TensorOptions::default()
662 })
663 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
664 ),
665 DType::Bitset => TensorValue::Bitset(
666 Tensor::from_vec_with_opts(vec![Bitset { bits: 0 }; len], TensorOptions {
667 shape: Some(shape.to_vec()),
668 ..TensorOptions::default()
669 })
670 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
671 ),
672 DType::F16 => TensorValue::F16(
673 Tensor::from_vec_with_opts(vec![F16 { bits: 0 }; len], TensorOptions {
674 shape: Some(shape.to_vec()),
675 ..TensorOptions::default()
676 })
677 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
678 ),
679 DType::BF16 => TensorValue::BF16(
680 Tensor::from_vec_with_opts(vec![BF16 { bits: 0 }; len], TensorOptions {
681 shape: Some(shape.to_vec()),
682 ..TensorOptions::default()
683 })
684 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
685 ),
686 DType::F8 => TensorValue::F8(
687 Tensor::from_vec_with_opts(vec![F8 { bits: 0 }; len], TensorOptions {
688 shape: Some(shape.to_vec()),
689 ..TensorOptions::default()
690 })
691 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
692 ),
693 DType::I4 => TensorValue::I4(
694 Tensor::from_vec_with_opts(vec![I4 { bits: 0 }; packed_len], TensorOptions {
695 shape: Some(shape.to_vec()),
696 allow_len_mismatch: true,
697 ..TensorOptions::default()
698 })
699 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
700 ),
701 DType::I2 => TensorValue::I2(
702 Tensor::from_vec_with_opts(vec![I2 { bits: 0 }; packed_len], TensorOptions {
703 shape: Some(shape.to_vec()),
704 allow_len_mismatch: true,
705 ..TensorOptions::default()
706 })
707 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
708 ),
709 DType::I1 => TensorValue::I1(
710 Tensor::from_vec_with_opts(vec![I1 { bits: 0 }; packed_len], TensorOptions {
711 shape: Some(shape.to_vec()),
712 allow_len_mismatch: true,
713 ..TensorOptions::default()
714 })
715 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
716 ),
717 DType::U4 => TensorValue::U4(
718 Tensor::from_vec_with_opts(vec![U4 { bits: 0 }; packed_len], TensorOptions {
719 shape: Some(shape.to_vec()),
720 allow_len_mismatch: true,
721 ..TensorOptions::default()
722 })
723 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
724 ),
725 DType::U2 => TensorValue::U2(
726 Tensor::from_vec_with_opts(vec![U2 { bits: 0 }; packed_len], TensorOptions {
727 shape: Some(shape.to_vec()),
728 allow_len_mismatch: true,
729 ..TensorOptions::default()
730 })
731 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
732 ),
733 DType::U1 => TensorValue::U1(
734 Tensor::from_vec_with_opts(vec![U1 { bits: 0 }; packed_len], TensorOptions {
735 shape: Some(shape.to_vec()),
736 allow_len_mismatch: true,
737 ..TensorOptions::default()
738 })
739 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
740 ),
741 DType::T2 => TensorValue::T2(
742 Tensor::from_vec_with_opts(vec![T2 { bits: 0 }; packed_len], TensorOptions {
743 shape: Some(shape.to_vec()),
744 allow_len_mismatch: true,
745 ..TensorOptions::default()
746 })
747 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
748 ),
749 DType::T1 => TensorValue::T1(
750 Tensor::from_vec_with_opts(vec![T1 { bits: 0 }; packed_len], TensorOptions {
751 shape: Some(shape.to_vec()),
752 allow_len_mismatch: true,
753 ..TensorOptions::default()
754 })
755 .unwrap_or_else(|err| panic!("tensor zeros failed: {}", err)),
756 ),
757 }
758 }
759
760 pub fn as_i8(&self) -> Result<&Tensor<i8>> {
762 match self {
763 TensorValue::I8(tensor) => Ok(tensor),
764 _ => Err(anyhow!("expected i8 tensor")),
765 }
766 }
767
768 pub fn as_i16(&self) -> Result<&Tensor<i16>> {
770 match self {
771 TensorValue::I16(tensor) => Ok(tensor),
772 _ => Err(anyhow!("expected i16 tensor")),
773 }
774 }
775
776 pub fn as_f32(&self) -> Result<&Tensor<f32>> {
778 match self {
779 TensorValue::F32(tensor) => Ok(tensor),
780 _ => Err(anyhow!("expected f32 tensor")),
781 }
782 }
783
784 pub fn as_f64(&self) -> Result<&Tensor<f64>> {
786 match self {
787 TensorValue::F64(tensor) => Ok(tensor),
788 _ => Err(anyhow!("expected f64 tensor")),
789 }
790 }
791
792 pub fn as_u8(&self) -> Result<&Tensor<u8>> {
794 match self {
795 TensorValue::U8(tensor) => Ok(tensor),
796 _ => Err(anyhow!("expected u8 tensor")),
797 }
798 }
799
800 pub fn as_u16(&self) -> Result<&Tensor<u16>> {
802 match self {
803 TensorValue::U16(tensor) => Ok(tensor),
804 _ => Err(anyhow!("expected u16 tensor")),
805 }
806 }
807
808 pub fn as_i32(&self) -> Result<&Tensor<i32>> {
810 match self {
811 TensorValue::I32(tensor) => Ok(tensor),
812 _ => Err(anyhow!("expected i32 tensor")),
813 }
814 }
815
816 pub fn as_i64(&self) -> Result<&Tensor<i64>> {
818 match self {
819 TensorValue::I64(tensor) => Ok(tensor),
820 _ => Err(anyhow!("expected i64 tensor")),
821 }
822 }
823
824 pub fn as_u32(&self) -> Result<&Tensor<u32>> {
826 match self {
827 TensorValue::U32(tensor) => Ok(tensor),
828 _ => Err(anyhow!("expected u32 tensor")),
829 }
830 }
831
832 pub fn as_u64(&self) -> Result<&Tensor<u64>> {
834 match self {
835 TensorValue::U64(tensor) => Ok(tensor),
836 _ => Err(anyhow!("expected u64 tensor")),
837 }
838 }
839
840 pub fn as_bool(&self) -> Result<&Tensor<bool>> {
842 match self {
843 TensorValue::Bool(tensor) => Ok(tensor),
844 _ => Err(anyhow!("expected bool tensor")),
845 }
846 }
847
848 pub fn as_bitset(&self) -> Result<&Tensor<Bitset>> {
850 match self {
851 TensorValue::Bitset(tensor) => Ok(tensor),
852 _ => Err(anyhow!("expected bitset tensor")),
853 }
854 }
855
856 pub fn as_f16(&self) -> Result<&Tensor<F16>> {
858 match self {
859 TensorValue::F16(tensor) => Ok(tensor),
860 _ => Err(anyhow!("expected f16 tensor")),
861 }
862 }
863
864 pub fn as_bf16(&self) -> Result<&Tensor<BF16>> {
866 match self {
867 TensorValue::BF16(tensor) => Ok(tensor),
868 _ => Err(anyhow!("expected bf16 tensor")),
869 }
870 }
871
872 pub fn as_f8(&self) -> Result<&Tensor<F8>> {
874 match self {
875 TensorValue::F8(tensor) => Ok(tensor),
876 _ => Err(anyhow!("expected f8 tensor")),
877 }
878 }
879
880 pub fn as_i4(&self) -> Result<&Tensor<I4>> {
882 match self {
883 TensorValue::I4(tensor) => Ok(tensor),
884 _ => Err(anyhow!("expected i4 tensor")),
885 }
886 }
887
888 pub fn as_i2(&self) -> Result<&Tensor<I2>> {
890 match self {
891 TensorValue::I2(tensor) => Ok(tensor),
892 _ => Err(anyhow!("expected i2 tensor")),
893 }
894 }
895
896 pub fn as_i1(&self) -> Result<&Tensor<I1>> {
898 match self {
899 TensorValue::I1(tensor) => Ok(tensor),
900 _ => Err(anyhow!("expected i1 tensor")),
901 }
902 }
903
904 pub fn as_u4(&self) -> Result<&Tensor<U4>> {
906 match self {
907 TensorValue::U4(tensor) => Ok(tensor),
908 _ => Err(anyhow!("expected u4 tensor")),
909 }
910 }
911
912 pub fn as_u2(&self) -> Result<&Tensor<U2>> {
914 match self {
915 TensorValue::U2(tensor) => Ok(tensor),
916 _ => Err(anyhow!("expected u2 tensor")),
917 }
918 }
919
920 pub fn as_u1(&self) -> Result<&Tensor<U1>> {
922 match self {
923 TensorValue::U1(tensor) => Ok(tensor),
924 _ => Err(anyhow!("expected u1 tensor")),
925 }
926 }
927
928 pub fn as_t2(&self) -> Result<&Tensor<T2>> {
930 match self {
931 TensorValue::T2(tensor) => Ok(tensor),
932 _ => Err(anyhow!("expected t2 tensor")),
933 }
934 }
935
936 pub fn as_t1(&self) -> Result<&Tensor<T1>> {
938 match self {
939 TensorValue::T1(tensor) => Ok(tensor),
940 _ => Err(anyhow!("expected t1 tensor")),
941 }
942 }
943}
944
945impl From<Tensor<i8>> for TensorValue {
946 fn from(value: Tensor<i8>) -> Self {
947 TensorValue::I8(value)
948 }
949}
950
951impl From<Tensor<i16>> for TensorValue {
952 fn from(value: Tensor<i16>) -> Self {
953 TensorValue::I16(value)
954 }
955}
956
957impl From<Tensor<f32>> for TensorValue {
958 fn from(value: Tensor<f32>) -> Self {
959 TensorValue::F32(value)
960 }
961}
962
963impl From<Tensor<f64>> for TensorValue {
964 fn from(value: Tensor<f64>) -> Self {
965 TensorValue::F64(value)
966 }
967}
968
969impl From<Tensor<BF16>> for TensorValue {
970 fn from(value: Tensor<BF16>) -> Self {
971 TensorValue::BF16(value)
972 }
973}
974
975impl From<Tensor<F8>> for TensorValue {
976 fn from(value: Tensor<F8>) -> Self {
977 TensorValue::F8(value)
978 }
979}
980
981impl From<Tensor<I4>> for TensorValue {
982 fn from(value: Tensor<I4>) -> Self {
983 TensorValue::I4(value)
984 }
985}
986
987impl From<Tensor<I2>> for TensorValue {
988 fn from(value: Tensor<I2>) -> Self {
989 TensorValue::I2(value)
990 }
991}
992
993impl From<Tensor<I1>> for TensorValue {
994 fn from(value: Tensor<I1>) -> Self {
995 TensorValue::I1(value)
996 }
997}
998
999impl From<Tensor<U4>> for TensorValue {
1000 fn from(value: Tensor<U4>) -> Self {
1001 TensorValue::U4(value)
1002 }
1003}
1004
1005impl From<Tensor<U2>> for TensorValue {
1006 fn from(value: Tensor<U2>) -> Self {
1007 TensorValue::U2(value)
1008 }
1009}
1010
1011impl From<Tensor<U1>> for TensorValue {
1012 fn from(value: Tensor<U1>) -> Self {
1013 TensorValue::U1(value)
1014 }
1015}
1016
1017impl From<Tensor<T2>> for TensorValue {
1018 fn from(value: Tensor<T2>) -> Self {
1019 TensorValue::T2(value)
1020 }
1021}
1022
1023impl From<Tensor<T1>> for TensorValue {
1024 fn from(value: Tensor<T1>) -> Self {
1025 TensorValue::T1(value)
1026 }
1027}
1028
1029impl From<Tensor<i32>> for TensorValue {
1030 fn from(value: Tensor<i32>) -> Self {
1031 TensorValue::I32(value)
1032 }
1033}
1034
1035impl From<Tensor<i64>> for TensorValue {
1036 fn from(value: Tensor<i64>) -> Self {
1037 TensorValue::I64(value)
1038 }
1039}
1040
1041impl From<Tensor<u8>> for TensorValue {
1042 fn from(value: Tensor<u8>) -> Self {
1043 TensorValue::U8(value)
1044 }
1045}
1046
1047impl From<Tensor<u16>> for TensorValue {
1048 fn from(value: Tensor<u16>) -> Self {
1049 TensorValue::U16(value)
1050 }
1051}
1052
1053impl From<Tensor<u32>> for TensorValue {
1054 fn from(value: Tensor<u32>) -> Self {
1055 TensorValue::U32(value)
1056 }
1057}
1058
1059impl From<Tensor<u64>> for TensorValue {
1060 fn from(value: Tensor<u64>) -> Self {
1061 TensorValue::U64(value)
1062 }
1063}
1064
1065impl From<Tensor<bool>> for TensorValue {
1066 fn from(value: Tensor<bool>) -> Self {
1067 TensorValue::Bool(value)
1068 }
1069}
1070
1071impl From<Tensor<Bitset>> for TensorValue {
1072 fn from(value: Tensor<Bitset>) -> Self {
1073 TensorValue::Bitset(value)
1074 }
1075}
1076
1077impl From<Tensor<F16>> for TensorValue {
1078 fn from(value: Tensor<F16>) -> Self {
1079 TensorValue::F16(value)
1080 }
1081}
1082
1083impl From<i8> for TensorValue {
1084 fn from(value: i8) -> Self {
1085 TensorValue::I8(Tensor::from_scalar(value))
1086 }
1087}
1088
1089impl From<i16> for TensorValue {
1090 fn from(value: i16) -> Self {
1091 TensorValue::I16(Tensor::from_scalar(value))
1092 }
1093}
1094
1095impl From<i32> for TensorValue {
1096 fn from(value: i32) -> Self {
1097 TensorValue::I32(Tensor::from_scalar(value))
1098 }
1099}
1100
1101impl From<i64> for TensorValue {
1102 fn from(value: i64) -> Self {
1103 TensorValue::I64(Tensor::from_scalar(value))
1104 }
1105}
1106
1107impl From<u8> for TensorValue {
1108 fn from(value: u8) -> Self {
1109 TensorValue::U8(Tensor::from_scalar(value))
1110 }
1111}
1112
1113impl From<u16> for TensorValue {
1114 fn from(value: u16) -> Self {
1115 TensorValue::U16(Tensor::from_scalar(value))
1116 }
1117}
1118
1119impl From<u32> for TensorValue {
1120 fn from(value: u32) -> Self {
1121 TensorValue::U32(Tensor::from_scalar(value))
1122 }
1123}
1124
1125impl From<u64> for TensorValue {
1126 fn from(value: u64) -> Self {
1127 TensorValue::U64(Tensor::from_scalar(value))
1128 }
1129}
1130
1131impl From<f32> for TensorValue {
1132 fn from(value: f32) -> Self {
1133 TensorValue::F32(Tensor::from_scalar(value))
1134 }
1135}
1136
1137impl From<f64> for TensorValue {
1138 fn from(value: f64) -> Self {
1139 TensorValue::F64(Tensor::from_scalar(value))
1140 }
1141}
1142
1143impl From<bool> for TensorValue {
1144 fn from(value: bool) -> Self {
1145 TensorValue::Bool(Tensor::from_scalar(value))
1146 }
1147}
1148
1149impl From<Bitset> for TensorValue {
1150 fn from(value: Bitset) -> Self {
1151 TensorValue::Bitset(Tensor::from_scalar(value))
1152 }
1153}
1154
1155impl From<F16> for TensorValue {
1156 fn from(value: F16) -> Self {
1157 TensorValue::F16(Tensor::from_scalar(value))
1158 }
1159}
1160
1161impl From<BF16> for TensorValue {
1162 fn from(value: BF16) -> Self {
1163 TensorValue::BF16(Tensor::from_scalar(value))
1164 }
1165}
1166
1167impl From<F8> for TensorValue {
1168 fn from(value: F8) -> Self {
1169 TensorValue::F8(Tensor::from_scalar(value))
1170 }
1171}
1172
1173impl From<I4> for TensorValue {
1174 fn from(value: I4) -> Self {
1175 TensorValue::I4(Tensor::from_scalar(value))
1176 }
1177}
1178
1179impl From<I2> for TensorValue {
1180 fn from(value: I2) -> Self {
1181 TensorValue::I2(Tensor::from_scalar(value))
1182 }
1183}
1184
1185impl From<I1> for TensorValue {
1186 fn from(value: I1) -> Self {
1187 TensorValue::I1(Tensor::from_scalar(value))
1188 }
1189}
1190
1191impl From<U4> for TensorValue {
1192 fn from(value: U4) -> Self {
1193 TensorValue::U4(Tensor::from_scalar(value))
1194 }
1195}
1196
1197impl From<U2> for TensorValue {
1198 fn from(value: U2) -> Self {
1199 TensorValue::U2(Tensor::from_scalar(value))
1200 }
1201}
1202
1203impl From<U1> for TensorValue {
1204 fn from(value: U1) -> Self {
1205 TensorValue::U1(Tensor::from_scalar(value))
1206 }
1207}
1208
1209impl From<T2> for TensorValue {
1210 fn from(value: T2) -> Self {
1211 TensorValue::T2(Tensor::from_scalar(value))
1212 }
1213}
1214
1215impl From<T1> for TensorValue {
1216 fn from(value: T1) -> Self {
1217 TensorValue::T1(Tensor::from_scalar(value))
1218 }
1219}