1use crate::{DType, Device, Error, Tensor, WithDType};
3use half::{bf16, f16, slice::HalfFloatSliceExt};
4use std::convert::TryFrom;
5
6impl<T: WithDType> TryFrom<&Tensor> for Vec<T> {
7 type Error = Error;
8 fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
9 tensor.to_vec1::<T>()
10 }
11}
12
13impl<T: WithDType> TryFrom<&Tensor> for Vec<Vec<T>> {
14 type Error = Error;
15 fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
16 tensor.to_vec2::<T>()
17 }
18}
19
20impl<T: WithDType> TryFrom<&Tensor> for Vec<Vec<Vec<T>>> {
21 type Error = Error;
22 fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
23 tensor.to_vec3::<T>()
24 }
25}
26
27impl<T: WithDType> TryFrom<Tensor> for Vec<T> {
28 type Error = Error;
29 fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
30 Vec::<T>::try_from(&tensor)
31 }
32}
33
34impl<T: WithDType> TryFrom<Tensor> for Vec<Vec<T>> {
35 type Error = Error;
36 fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
37 Vec::<Vec<T>>::try_from(&tensor)
38 }
39}
40
41impl<T: WithDType> TryFrom<Tensor> for Vec<Vec<Vec<T>>> {
42 type Error = Error;
43 fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
44 Vec::<Vec<Vec<T>>>::try_from(&tensor)
45 }
46}
47
48impl<T: WithDType> TryFrom<&[T]> for Tensor {
49 type Error = Error;
50 fn try_from(v: &[T]) -> Result<Self, Self::Error> {
51 Tensor::from_slice(v, v.len(), &Device::Cpu)
52 }
53}
54
55impl<T: WithDType> TryFrom<Vec<T>> for Tensor {
56 type Error = Error;
57 fn try_from(v: Vec<T>) -> Result<Self, Self::Error> {
58 let len = v.len();
59 Tensor::from_vec(v, len, &Device::Cpu)
60 }
61}
62
// Generates the three scalar conversions for an element type `$typ`:
// `&Tensor -> $typ` and `Tensor -> $typ` (both via `Tensor::to_scalar`),
// and `$typ -> Tensor` (a CPU tensor built with `Tensor::new`).
macro_rules! from_tensor {
    ($typ:ident) => {
        impl TryFrom<&Tensor> for $typ {
            type Error = Error;

            fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
                tensor.to_scalar::<Self>()
            }
        }

        impl TryFrom<Tensor> for $typ {
            type Error = Error;

            fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
                // Reuse the borrowed conversion above.
                <Self as TryFrom<&Tensor>>::try_from(&tensor)
            }
        }

        impl TryFrom<$typ> for Tensor {
            type Error = Error;

            fn try_from(v: $typ) -> Result<Self, Self::Error> {
                Self::new(v, &Device::Cpu)
            }
        }
    };
}
90
91from_tensor!(f64);
92from_tensor!(f32);
93from_tensor!(f16);
94from_tensor!(bf16);
95from_tensor!(i64);
96from_tensor!(u32);
97from_tensor!(u8);
98
99impl Tensor {
100 pub fn write_bytes<W: std::io::Write>(&self, f: &mut W) -> crate::Result<()> {
101 use byteorder::{LittleEndian, WriteBytesExt};
102
103 let vs = self.flatten_all()?;
104 match self.dtype() {
105 DType::BF16 => {
106 let vs = vs.to_vec1::<bf16>()?;
107 for &v in vs.reinterpret_cast() {
108 f.write_u16::<LittleEndian>(v)?
109 }
110 }
111 DType::F16 => {
112 let vs = vs.to_vec1::<f16>()?;
113 for &v in vs.reinterpret_cast() {
114 f.write_u16::<LittleEndian>(v)?
115 }
116 }
117 DType::F32 => {
118 for v in vs.to_vec1::<f32>()? {
120 f.write_f32::<LittleEndian>(v)?
121 }
122 }
123 DType::F64 => {
124 for v in vs.to_vec1::<f64>()? {
125 f.write_f64::<LittleEndian>(v)?
126 }
127 }
128 DType::U32 => {
129 for v in vs.to_vec1::<u32>()? {
130 f.write_u32::<LittleEndian>(v)?
131 }
132 }
133 DType::I64 => {
134 for v in vs.to_vec1::<i64>()? {
135 f.write_i64::<LittleEndian>(v)?
136 }
137 }
138 DType::U8 => {
139 let vs = vs.to_vec1::<u8>()?;
140 f.write_all(&vs)?;
141 }
142 }
143 Ok(())
144 }
145}