candle_core/
convert.rs

1//! Implement conversion traits for tensors
2use crate::{DType, Device, Error, Tensor, WithDType};
3use half::{bf16, f16, slice::HalfFloatSliceExt};
4use std::convert::TryFrom;
5
6impl<T: WithDType> TryFrom<&Tensor> for Vec<T> {
7    type Error = Error;
8    fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
9        tensor.to_vec1::<T>()
10    }
11}
12
13impl<T: WithDType> TryFrom<&Tensor> for Vec<Vec<T>> {
14    type Error = Error;
15    fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
16        tensor.to_vec2::<T>()
17    }
18}
19
20impl<T: WithDType> TryFrom<&Tensor> for Vec<Vec<Vec<T>>> {
21    type Error = Error;
22    fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
23        tensor.to_vec3::<T>()
24    }
25}
26
27impl<T: WithDType> TryFrom<Tensor> for Vec<T> {
28    type Error = Error;
29    fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
30        Vec::<T>::try_from(&tensor)
31    }
32}
33
34impl<T: WithDType> TryFrom<Tensor> for Vec<Vec<T>> {
35    type Error = Error;
36    fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
37        Vec::<Vec<T>>::try_from(&tensor)
38    }
39}
40
41impl<T: WithDType> TryFrom<Tensor> for Vec<Vec<Vec<T>>> {
42    type Error = Error;
43    fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
44        Vec::<Vec<Vec<T>>>::try_from(&tensor)
45    }
46}
47
48impl<T: WithDType> TryFrom<&[T]> for Tensor {
49    type Error = Error;
50    fn try_from(v: &[T]) -> Result<Self, Self::Error> {
51        Tensor::from_slice(v, v.len(), &Device::Cpu)
52    }
53}
54
55impl<T: WithDType> TryFrom<Vec<T>> for Tensor {
56    type Error = Error;
57    fn try_from(v: Vec<T>) -> Result<Self, Self::Error> {
58        let len = v.len();
59        Tensor::from_vec(v, len, &Device::Cpu)
60    }
61}
62
// Generates three conversions for the scalar element type `$scalar`:
// - `&Tensor -> $scalar` via `Tensor::to_scalar`,
// - `Tensor -> $scalar`, delegating to the borrowing impl,
// - `$scalar -> Tensor`, building a CPU tensor with `Tensor::new`.
macro_rules! from_tensor {
    ($scalar:ident) => {
        impl TryFrom<&Tensor> for $scalar {
            type Error = Error;

            fn try_from(t: &Tensor) -> Result<Self, Self::Error> {
                Tensor::to_scalar::<$scalar>(t)
            }
        }

        impl TryFrom<Tensor> for $scalar {
            type Error = Error;

            fn try_from(t: Tensor) -> Result<Self, Self::Error> {
                // Reuse the `&Tensor` impl generated above.
                Self::try_from(&t)
            }
        }

        impl TryFrom<$scalar> for Tensor {
            type Error = Error;

            fn try_from(value: $scalar) -> Result<Self, Self::Error> {
                Self::new(value, &Device::Cpu)
            }
        }
    };
}
90
91from_tensor!(f64);
92from_tensor!(f32);
93from_tensor!(f16);
94from_tensor!(bf16);
95from_tensor!(i64);
96from_tensor!(u32);
97from_tensor!(u8);
98
99impl Tensor {
100    pub fn write_bytes<W: std::io::Write>(&self, f: &mut W) -> crate::Result<()> {
101        use byteorder::{LittleEndian, WriteBytesExt};
102
103        let vs = self.flatten_all()?;
104        match self.dtype() {
105            DType::BF16 => {
106                let vs = vs.to_vec1::<bf16>()?;
107                for &v in vs.reinterpret_cast() {
108                    f.write_u16::<LittleEndian>(v)?
109                }
110            }
111            DType::F16 => {
112                let vs = vs.to_vec1::<f16>()?;
113                for &v in vs.reinterpret_cast() {
114                    f.write_u16::<LittleEndian>(v)?
115                }
116            }
117            DType::F32 => {
118                // TODO: Avoid using a buffer when data is already on the CPU.
119                for v in vs.to_vec1::<f32>()? {
120                    f.write_f32::<LittleEndian>(v)?
121                }
122            }
123            DType::F64 => {
124                for v in vs.to_vec1::<f64>()? {
125                    f.write_f64::<LittleEndian>(v)?
126                }
127            }
128            DType::U32 => {
129                for v in vs.to_vec1::<u32>()? {
130                    f.write_u32::<LittleEndian>(v)?
131                }
132            }
133            DType::I64 => {
134                for v in vs.to_vec1::<i64>()? {
135                    f.write_i64::<LittleEndian>(v)?
136                }
137            }
138            DType::U8 => {
139                let vs = vs.to_vec1::<u8>()?;
140                f.write_all(&vs)?;
141            }
142        }
143        Ok(())
144    }
145}