candle_core/
convert.rs

1//! Implement conversion traits for tensors
2use crate::{DType, Device, Error, Tensor, WithDType};
3use half::{bf16, f16, slice::HalfFloatSliceExt};
4use std::convert::TryFrom;
5
6impl<T: WithDType> TryFrom<&Tensor> for Vec<T> {
7    type Error = Error;
8    fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
9        tensor.to_vec1::<T>()
10    }
11}
12
13impl<T: WithDType> TryFrom<&Tensor> for Vec<Vec<T>> {
14    type Error = Error;
15    fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
16        tensor.to_vec2::<T>()
17    }
18}
19
20impl<T: WithDType> TryFrom<&Tensor> for Vec<Vec<Vec<T>>> {
21    type Error = Error;
22    fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
23        tensor.to_vec3::<T>()
24    }
25}
26
27impl<T: WithDType> TryFrom<Tensor> for Vec<T> {
28    type Error = Error;
29    fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
30        Vec::<T>::try_from(&tensor)
31    }
32}
33
34impl<T: WithDType> TryFrom<Tensor> for Vec<Vec<T>> {
35    type Error = Error;
36    fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
37        Vec::<Vec<T>>::try_from(&tensor)
38    }
39}
40
41impl<T: WithDType> TryFrom<Tensor> for Vec<Vec<Vec<T>>> {
42    type Error = Error;
43    fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
44        Vec::<Vec<Vec<T>>>::try_from(&tensor)
45    }
46}
47
48impl<T: WithDType> TryFrom<&[T]> for Tensor {
49    type Error = Error;
50    fn try_from(v: &[T]) -> Result<Self, Self::Error> {
51        Tensor::from_slice(v, v.len(), &Device::Cpu)
52    }
53}
54
55impl<T: WithDType> TryFrom<Vec<T>> for Tensor {
56    type Error = Error;
57    fn try_from(v: Vec<T>) -> Result<Self, Self::Error> {
58        let len = v.len();
59        Tensor::from_vec(v, len, &Device::Cpu)
60    }
61}
62
/// Generates the scalar conversions between [`Tensor`] and one scalar type.
///
/// For the given type this expands to three `TryFrom` implementations:
/// `&Tensor -> scalar`, `Tensor -> scalar` and `scalar -> Tensor`.
macro_rules! from_tensor {
    ($scalar:ident) => {
        // Borrowed tensor -> scalar, via `Tensor::to_scalar`.
        impl TryFrom<&Tensor> for $scalar {
            type Error = Error;

            fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
                tensor.to_scalar::<Self>()
            }
        }

        // Owned tensor -> scalar, delegating to the borrowed conversion.
        impl TryFrom<Tensor> for $scalar {
            type Error = Error;

            fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
                <Self as TryFrom<&Tensor>>::try_from(&tensor)
            }
        }

        // Scalar -> tensor, created on the CPU device.
        impl TryFrom<$scalar> for Tensor {
            type Error = Error;

            fn try_from(value: $scalar) -> Result<Self, Self::Error> {
                Self::new(value, &Device::Cpu)
            }
        }
    };
}
90
91from_tensor!(f64);
92from_tensor!(f32);
93from_tensor!(f16);
94from_tensor!(bf16);
95from_tensor!(i64);
96from_tensor!(i32);
97from_tensor!(i16);
98from_tensor!(u32);
99from_tensor!(u8);
100
101impl Tensor {
102    pub fn write_bytes<W: std::io::Write>(&self, f: &mut W) -> crate::Result<()> {
103        use byteorder::{LittleEndian, WriteBytesExt};
104
105        let vs = self.flatten_all()?;
106        match self.dtype() {
107            DType::BF16 => {
108                let vs = vs.to_vec1::<bf16>()?;
109                for &v in vs.reinterpret_cast() {
110                    f.write_u16::<LittleEndian>(v)?
111                }
112            }
113            DType::F16 => {
114                let vs = vs.to_vec1::<f16>()?;
115                for &v in vs.reinterpret_cast() {
116                    f.write_u16::<LittleEndian>(v)?
117                }
118            }
119            DType::F32 => {
120                // TODO: Avoid using a buffer when data is already on the CPU.
121                for v in vs.to_vec1::<f32>()? {
122                    f.write_f32::<LittleEndian>(v)?
123                }
124            }
125            DType::F64 => {
126                for v in vs.to_vec1::<f64>()? {
127                    f.write_f64::<LittleEndian>(v)?
128                }
129            }
130            DType::U32 => {
131                for v in vs.to_vec1::<u32>()? {
132                    f.write_u32::<LittleEndian>(v)?
133                }
134            }
135            DType::I16 => {
136                for v in vs.to_vec1::<i16>()? {
137                    f.write_i16::<LittleEndian>(v)?
138                }
139            }
140            DType::I32 => {
141                for v in vs.to_vec1::<i32>()? {
142                    f.write_i32::<LittleEndian>(v)?
143                }
144            }
145            DType::I64 => {
146                for v in vs.to_vec1::<i64>()? {
147                    f.write_i64::<LittleEndian>(v)?
148                }
149            }
150            DType::U8 => {
151                let vs = vs.to_vec1::<u8>()?;
152                f.write_all(&vs)?;
153            }
154            DType::F8E4M3 => {
155                let vs = vs.to_vec1::<float8::F8E4M3>()?;
156                for v in vs {
157                    f.write_u8(v.to_bits())?
158                }
159            }
160            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
161                return Err(crate::Error::UnsupportedDTypeForOp(self.dtype(), "write_bytes").bt())
162            }
163        }
164        Ok(())
165    }
166}