lumen_core/tensor/
convert.rs1use crate::{AutogradMetaT, DTypeConvert, Error, Result, TensorOrScalar, WithDType};
2use super::Tensor;
3
4impl<T: WithDType> Tensor<T> {
5 pub fn contiguous(&self) -> crate::Result<Tensor<T>> {
6 if self.is_contiguous() {
7 Ok(self.clone())
8 } else {
9 self.copy()
10 }
11 }
12
13 pub fn copy(&self) -> crate::Result<Self> {
14 let storage = self.storage_read()?.copy(self.layout());
15 let meta = T::AutogradMeta::on_copy_op(self);
16 Ok(Self::from_storage(storage, self.shape(), meta))
17 }
18
19 pub fn copy_from(&self, source: &Self) -> Result<()> {
20 if self.shape() != source.shape() {
21 Err(Error::ShapeMismatchCopyFrom { dst: self.shape().clone(), src: source.shape().clone() })?
22 }
23
24 let mut storage = self.storage_write()?;
25 for (self_storage_index, src_value) in self.layout().storage_indices().zip(source.iter()?) {
26 storage.set_unchecked(self_storage_index, src_value);
27 }
28
29 Ok(())
30 }
31
32 pub fn assign(&self, source: impl Into<TensorOrScalar<T>>) -> Result<()> {
34 match source.into() {
35 TensorOrScalar::Scalar(src) => {
36 let mut storage = self.storage_write()?;
37 for storage_index in self.layout().storage_indices() {
38 storage.set_unchecked(storage_index, src);
39 }
40 Ok(())
41 }
42 TensorOrScalar::Tensor(src) => {
43 if src.shape() != self.shape() {
44 Err(Error::ShapeMismatchCopyFrom { dst: self.shape().clone(), src: src.shape().clone() })?
45 }
46
47 let mut storage = self.storage_write()?;
48 for (self_storage_index, src_value) in self.layout().storage_indices().zip(src.iter()?) {
49 storage.set_unchecked(self_storage_index, src_value);
50 }
51
52 Ok(())
53 }
54 }
55 }
56}
57
58
59impl<From: WithDType> Tensor<From> {
60 pub fn cast<To: WithDType>(&self) -> crate::Result<Tensor<To>>
61 where
62 From: DTypeConvert<To>,
63 {
64 let storage = self.storage_read()?.copy_map(self.layout(), From::convert);
71 Ok(Tensor::<To>::from_storage(storage, self.shape(), Default::default()))
72 }
73}
74
#[cfg(test)]
mod tests {
    use crate::IndexOp;

    use super::*;

    /// Asserts that two `i32` tensors have the same shape and element values.
    fn assert_same_values(actual: &Tensor<i32>, expected: &Tensor<i32>) {
        assert_eq!(actual.shape(), expected.shape());
        assert_eq!(actual.to_vec().unwrap(), expected.to_vec().unwrap());
    }

    #[test]
    fn test_assign() {
        let a = Tensor::new(&[[1i32, 2, 3], [3, 4, 5], [4, 5, 6]]).unwrap();

        // Assigning through a view must write into `a`'s storage.
        a.index(0).unwrap().assign(100).unwrap();
        let expected = Tensor::new(&[[100i32, 100, 100], [3, 4, 5], [4, 5, 6]]).unwrap();
        assert_same_values(&a, &expected);

        a.index((1, 1)).unwrap().assign(200).unwrap();
        let expected = Tensor::new(&[[100i32, 100, 100], [3, 200, 5], [4, 5, 6]]).unwrap();
        assert_same_values(&a, &expected);

        a.index((1.., 1..)).unwrap().assign(999).unwrap();
        let expected =
            Tensor::new(&[[100i32, 100, 100], [3, 999, 999], [4, 999, 999]]).unwrap();
        assert_same_values(&a, &expected);
    }

    #[test]
    fn test_copy_from_whole_tensor() {
        let a = Tensor::new(&[[1i32, 2], [3, 4]]).unwrap();
        let b = Tensor::new(&[[5i32, 6], [7, 8]]).unwrap();
        a.copy_from(&b).unwrap();
        assert_same_values(&a, &b);
    }

    #[test]
    fn test_copy_from_into_view() {
        let a = Tensor::new(&[[1i32, 2, 3], [4, 5, 6], [7, 8, 9]]).unwrap();
        let patch = Tensor::new(&[[50i32, 60], [80, 90]]).unwrap();
        a.index((1.., 1..)).unwrap().copy_from(&patch).unwrap();
        let expected = Tensor::new(&[[1i32, 2, 3], [4, 50, 60], [7, 80, 90]]).unwrap();
        assert_same_values(&a, &expected);
    }

    #[test]
    fn test_copy_from_shape_mismatch() {
        let a = Tensor::new(&[[1i32, 2], [3, 4]]).unwrap();
        let b = Tensor::new(&[1i32, 2, 3, 4]).unwrap();
        assert!(a.copy_from(&b).is_err());
        // A rejected copy must leave the destination untouched.
        let expected = Tensor::new(&[[1i32, 2], [3, 4]]).unwrap();
        assert_same_values(&a, &expected);
    }

    #[test]
    fn test_copy_1d() {
        let a = Tensor::new(&[1, 2, 3]).unwrap();
        let b = a.copy().unwrap();
        assert_eq!(a.shape(), b.shape());
        assert_eq!(a.to_vec().unwrap(), b.to_vec().unwrap());
    }

    #[test]
    fn test_copy_2d() {
        let a = Tensor::new(&[[1, 2], [3, 4]]).unwrap();
        let b = a.copy().unwrap();
        assert_eq!(a.shape(), b.shape());
        assert_eq!(a.to_vec().unwrap(), b.to_vec().unwrap());
    }

    #[test]
    fn test_copy_is_independent() {
        let a = Tensor::new(&[[1i32, 2], [3, 4]]).unwrap();
        let b = a.copy().unwrap();
        b.assign(0).unwrap();
        // Writing into the copy must not affect the original.
        let expected = Tensor::new(&[[1i32, 2], [3, 4]]).unwrap();
        assert_same_values(&a, &expected);
    }

    #[test]
    fn test_cast_i32_to_f32() {
        let a = Tensor::new(&[1i32, 2, 3]).unwrap();
        let b: Tensor<f32> = a.cast().unwrap();
        assert_eq!(b.shape(), a.shape());
        let expected = Tensor::new(&[1.0f32, 2.0, 3.0]).unwrap();
        assert!(b.allclose(&expected, 1e-6, 1e-6).unwrap());
    }

    #[test]
    fn test_cast_f64_to_i32() {
        let a = Tensor::new(&[1.5f64, 2.7, 3.0]).unwrap();
        let b: Tensor<i32> = a.cast().unwrap();
        assert_eq!(b.shape(), a.shape());
        let expected = Tensor::new(&[1i32, 2, 3]).unwrap();
        assert_eq!(b.to_vec().unwrap(), expected.to_vec().unwrap());
    }

    #[test]
    fn test_cast_2d() {
        let a = Tensor::new(&[[1i32, 2], [3, 4]]).unwrap();
        let b: Tensor<f64> = a.cast().unwrap();
        assert_eq!(b.shape(), a.shape());
        let expected = Tensor::new(&[[1.0, 2.0], [3.0, 4.0]]).unwrap();
        assert!(b.allclose(&expected, 1e-12, 1e-12).unwrap());
    }

    #[test]
    fn test_copy_vs_cast() {
        let a = Tensor::new(&[1i32, 2, 3]).unwrap();
        let b = a.copy().unwrap();
        let c: Tensor<f32> = a.cast().unwrap();

        assert_eq!(b.to_vec().unwrap(), a.to_vec().unwrap());

        let expected = Tensor::new(&[1.0f32, 2.0, 3.0]).unwrap();
        assert!(c.allclose(&expected, 1e-6, 1e-6).unwrap());
    }
}