//! `lumen_core/tensor/convert.rs` — copy, in-place assignment, and dtype-cast
//! operations for [`Tensor`].
1use crate::{AutogradMetaT, DTypeConvert, Error, Result, TensorOrScalar, WithDType};
2use super::Tensor;
3
4impl<T: WithDType> Tensor<T> {
5    pub fn contiguous(&self) -> crate::Result<Tensor<T>> {
6        if self.is_contiguous() {
7            Ok(self.clone())
8        } else {
9            self.copy()
10        }
11    }
12
13    pub fn copy(&self) -> crate::Result<Self> {
14        let storage = self.storage_read()?.copy(self.layout());
15        let meta = T::AutogradMeta::on_copy_op(self);
16        Ok(Self::from_storage(storage, self.shape(), meta))
17    }
18
19    pub fn copy_from(&self, source: &Self) -> Result<()> {
20        if self.shape() != source.shape() {
21            Err(Error::ShapeMismatchCopyFrom { dst: self.shape().clone(), src: source.shape().clone() })?
22        }
23
24        let mut storage = self.storage_write()?;
25        for (self_storage_index, src_value) in self.layout().storage_indices().zip(source.iter()?) {
26            storage.set_unchecked(self_storage_index, src_value);
27        }
28
29        Ok(())
30    }
31
32    // no grad record
33    pub fn assign(&self, source: impl Into<TensorOrScalar<T>>) -> Result<()> {
34        match source.into() {
35            TensorOrScalar::Scalar(src) => {
36                let mut storage = self.storage_write()?;
37                for storage_index in self.layout().storage_indices() {
38                    storage.set_unchecked(storage_index, src);
39                }
40                Ok(())
41            }
42            TensorOrScalar::Tensor(src) => {
43                if src.shape() != self.shape() {
44                    Err(Error::ShapeMismatchCopyFrom { dst: self.shape().clone(), src: src.shape().clone() })?
45                }
46        
47                let mut storage = self.storage_write()?;
48                for (self_storage_index, src_value) in self.layout().storage_indices().zip(src.iter()?) {
49                    storage.set_unchecked(self_storage_index, src_value);
50                }
51        
52                Ok(())
53            }
54        }
55    }
56}
57
58
59impl<From: WithDType> Tensor<From> {
60    pub fn cast<To: WithDType>(&self) -> crate::Result<Tensor<To>> 
61    where
62        From: DTypeConvert<To>,
63    {
64        // if TypeId::of::<From>() == TypeId::of::<To>() {
65        //     let self_any = self as &dyn Any;            
66        //     if let Some(same_tensor) = self_any.downcast_ref::<Tensor<To>>() {
67        //         return same_tensor.clone();
68        //     }
69        // }
70        let storage = self.storage_read()?.copy_map(self.layout(), From::convert);
71        Ok(Tensor::<To>::from_storage(storage, self.shape(), Default::default()))
72    }
73}
74
#[cfg(test)]
mod tests {
    use crate::IndexOp;

    use super::*;

    /// Assigning through an index view must mutate the parent's storage.
    #[test]
    fn test_assign() {
        let a = Tensor::new(&[[1, 2, 3], [3, 4, 5], [4, 5, 6]]).unwrap();

        // Whole-row view.
        a.index(0).unwrap().assign(100).unwrap();
        let expected = Tensor::new(&[[100, 100, 100], [3, 4, 5], [4, 5, 6]]).unwrap();
        assert_eq!(a.to_vec().unwrap(), expected.to_vec().unwrap());

        // Single-element view.
        a.index((1, 1)).unwrap().assign(200).unwrap();
        let expected = Tensor::new(&[[100, 100, 100], [3, 200, 5], [4, 5, 6]]).unwrap();
        assert_eq!(a.to_vec().unwrap(), expected.to_vec().unwrap());

        // Strided sub-block view.
        a.index((1.., 1..)).unwrap().assign(999).unwrap();
        let expected = Tensor::new(&[[100, 100, 100], [3, 999, 999], [4, 999, 999]]).unwrap();
        assert_eq!(a.to_vec().unwrap(), expected.to_vec().unwrap());
    }

    #[test]
    fn test_copy_1d() {
        let a = Tensor::new(&[1, 2, 3]).unwrap();
        let b = a.copy().unwrap();
        assert_eq!(a.shape(), b.shape());
        assert_eq!(a.to_vec().unwrap(), b.to_vec().unwrap());
    }

    #[test]
    fn test_copy_2d() {
        let a = Tensor::new(&[[1, 2], [3, 4]]).unwrap();
        let b = a.copy().unwrap();
        assert_eq!(a.shape(), b.shape());
        assert_eq!(a.to_vec().unwrap(), b.to_vec().unwrap());
    }

    #[test]
    fn test_cast_i32_to_f32() {
        let a = Tensor::new(&[1i32, 2, 3]).unwrap();
        let b: Tensor<f32> = a.cast().unwrap();
        assert_eq!(b.shape(), a.shape());
        let expected = Tensor::new(&[1.0f32, 2.0, 3.0]).unwrap();
        assert!(b.allclose(&expected, 1e-6, 1e-6).unwrap());
    }

    #[test]
    fn test_cast_f64_to_i32() {
        let a = Tensor::new(&[1.5f64, 2.7, 3.0]).unwrap();
        let b: Tensor<i32> = a.cast().unwrap();
        assert_eq!(b.shape(), a.shape());
        let expected = Tensor::new(&[1i32, 2, 3]).unwrap(); // cast truncates
        assert_eq!(b.to_vec().unwrap(), expected.to_vec().unwrap());
    }

    #[test]
    fn test_cast_2d() {
        let a = Tensor::new(&[[1i32, 2], [3, 4]]).unwrap();
        let b: Tensor<f64> = a.cast().unwrap();
        assert_eq!(b.shape(), a.shape());
        let expected = Tensor::new(&[[1.0, 2.0], [3.0, 4.0]]).unwrap();
        assert!(b.allclose(&expected, 1e-12, 1e-12).unwrap());
    }

    /// `copy` preserves dtype and values; `cast` converts dtype.
    #[test]
    fn test_copy_vs_cast() {
        let a = Tensor::new(&[1i32, 2, 3]).unwrap();
        let b = a.copy().unwrap();
        let c: Tensor<f32> = a.cast().unwrap();

        assert_eq!(b.to_vec().unwrap(), a.to_vec().unwrap());

        let expected = Tensor::new(&[1.0f32, 2.0, 3.0]).unwrap();
        assert!(c.allclose(&expected, 1e-6, 1e-6).unwrap());
    }
}