1use crate::prelude_dev::*;
2
3impl<R, T, B, D> TensorAny<R, T, B, D>
7where
8 D: DimAPI,
9 B: DeviceAPI<T>,
10 R: DataAPI<Data = B::Raw>,
11{
12 pub fn view(&self) -> TensorView<'_, T, B, D> {
14 let layout = self.layout().clone();
15 let data = self.data().as_ref();
16 let storage = Storage::new(data, self.device().clone());
17 unsafe { TensorBase::new_unchecked(storage, layout) }
18 }
19
20 pub fn view_mut(&mut self) -> TensorMut<'_, T, B, D>
22 where
23 R: DataMutAPI,
24 {
25 let device = self.device().clone();
26 let layout = self.layout().clone();
27 let data = self.data_mut().as_mut();
28 let storage = Storage::new(data, device);
29 unsafe { TensorBase::new_unchecked(storage, layout) }
30 }
31
32 pub fn into_cow<'a>(self) -> TensorCow<'a, T, B, D>
34 where
35 R: DataIntoCowAPI<'a>,
36 {
37 let (storage, layout) = self.into_raw_parts();
38 let (data, device) = storage.into_raw_parts();
39 let storage = Storage::new(data.into_cow(), device);
40 unsafe { TensorBase::new_unchecked(storage, layout) }
41 }
42
43 pub fn into_owned_keep_layout(self) -> Tensor<T, B, D>
55 where
56 R::Data: Clone,
57 R: DataCloneAPI,
58 {
59 let (storage, layout) = self.into_raw_parts();
60 let (data, device) = storage.into_raw_parts();
61 let storage = Storage::new(data.into_owned(), device);
62 unsafe { TensorBase::new_unchecked(storage, layout) }
63 }
64
65 pub fn into_shared_keep_layout(self) -> TensorArc<T, B, D>
77 where
78 R::Data: Clone,
79 R: DataCloneAPI,
80 {
81 let (storage, layout) = self.into_raw_parts();
82 let (data, device) = storage.into_raw_parts();
83 let storage = Storage::new(data.into_shared(), device);
84 unsafe { TensorBase::new_unchecked(storage, layout) }
85 }
86}
87
88impl<R, T, B, D> TensorAny<R, T, B, D>
89where
90 R: DataCloneAPI<Data = B::Raw>,
91 R::Data: Clone,
92 D: DimAPI,
93 T: Clone,
94 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
95{
96 pub fn into_owned(self) -> Tensor<T, B, D> {
97 let (idx_min, idx_max) = self.layout().bounds_index().unwrap();
98 if idx_min == 0 && idx_max == self.storage().len() && idx_max == self.layout().size() {
99 return self.into_owned_keep_layout();
100 } else {
101 return asarray((&self, TensorIterOrder::K));
102 }
103 }
104
105 pub fn into_shared(self) -> TensorArc<T, B, D> {
106 let (idx_min, idx_max) = self.layout().bounds_index().unwrap();
107 if idx_min == 0 && idx_max == self.storage().len() && idx_max == self.layout().size() {
108 return self.into_shared_keep_layout();
109 } else {
110 return asarray((&self, TensorIterOrder::K)).into_shared();
111 }
112 }
113
114 pub fn to_owned(&self) -> Tensor<T, B, D> {
115 self.view().into_owned()
116 }
117}
118
119impl<T, B, D> Clone for Tensor<T, B, D>
120where
121 T: Clone,
122 D: DimAPI,
123 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
124 B::Raw: Clone,
125{
126 fn clone(&self) -> Self {
127 self.to_owned()
128 }
129}
130
131impl<T, B, D> Clone for TensorCow<'_, T, B, D>
132where
133 T: Clone,
134 D: DimAPI,
135 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
136 B::Raw: Clone,
137{
138 fn clone(&self) -> Self {
139 let tsr_owned = self.to_owned();
140 let (storage, layout) = tsr_owned.into_raw_parts();
141 let (data, device) = storage.into_raw_parts();
142 let data = data.into_cow();
143 let storage = Storage::new(data, device);
144 unsafe { TensorBase::new_unchecked(storage, layout) }
145 }
146}
147
148impl<R, T, B, D> TensorAny<R, T, B, D>
149where
150 R: DataAPI<Data = B::Raw> + DataForceMutAPI<B::Raw>,
151 B: DeviceAPI<T>,
152 D: DimAPI,
153{
154 pub unsafe fn force_mut(&self) -> TensorMut<'_, T, B, D> {
159 let layout = self.layout().clone();
160 let data = self.data().force_mut();
161 let storage = Storage::new(data, self.device().clone());
162 TensorBase::new_unchecked(storage, layout)
163 }
164}
165
166impl<R, T, B, D> TensorAny<R, T, B, D>
171where
172 R: DataAPI<Data = B::Raw>,
173 T: Clone,
174 D: DimAPI,
175 B: DeviceAPI<T, Raw = Vec<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
176{
177 pub fn to_raw_f(&self) -> Result<Vec<T>> {
178 rstsr_assert_eq!(self.ndim(), 1, InvalidLayout, "to_vec currently only support 1-D tensor")?;
179 let device = self.device();
180 let layout = self.layout().to_dim::<Ix1>()?;
181 let size = layout.size();
182 let mut new_storage = unsafe { device.empty_impl(size)? };
183 device.assign(new_storage.raw_mut(), &[size].c(), self.raw(), &layout)?;
184 let (data, _) = new_storage.into_raw_parts();
185 Ok(data.into_raw())
186 }
187
188 pub fn to_vec(&self) -> Vec<T> {
189 self.to_raw_f().unwrap()
190 }
191}
192
193impl<T, B, D> Tensor<T, B, D>
194where
195 T: Clone,
196 D: DimAPI,
197 B: DeviceAPI<T, Raw = Vec<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
198{
199 pub fn into_vec_f(self) -> Result<Vec<T>> {
200 rstsr_assert_eq!(self.ndim(), 1, InvalidLayout, "to_vec currently only support 1-D tensor")?;
201 let layout = self.layout();
202 let (idx_min, idx_max) = layout.bounds_index()?;
203 if idx_min == 0 && idx_max == self.storage().len() && idx_max == layout.size() && layout.stride()[0] > 0 {
204 let (storage, _) = self.into_raw_parts();
205 let (data, _) = storage.into_raw_parts();
206 return Ok(data.into_raw());
207 } else {
208 return self.to_raw_f();
209 }
210 }
211
212 pub fn into_vec(self) -> Vec<T> {
213 self.into_vec_f().unwrap()
214 }
215}
216
217impl<R, T, B, D> TensorAny<R, T, B, D>
222where
223 R: DataCloneAPI<Data = B::Raw>,
224 B::Raw: Clone,
225 T: Clone,
226 D: DimAPI,
227 B: DeviceAPI<T>,
228{
229 pub fn to_scalar_f(&self) -> Result<T> {
230 let layout = self.layout();
231 rstsr_assert_eq!(layout.size(), 1, InvalidLayout)?;
232 let storage = self.storage();
233 let vec = storage.to_cpu_vec()?;
234 Ok(vec[0].clone())
235 }
236
237 pub fn to_scalar(&self) -> T {
238 self.to_scalar_f().unwrap()
239 }
240}
241
242impl<R, T, B, D> TensorAny<R, T, B, D>
247where
248 R: DataAPI<Data = B::Raw>,
249 D: DimAPI,
250 B: DeviceAPI<T, Raw = Vec<T>>,
251{
252 pub fn as_ptr(&self) -> *const T {
253 unsafe { self.raw().as_ptr().add(self.layout().offset()) }
254 }
255
256 pub fn as_mut_ptr(&mut self) -> *mut T
257 where
258 R: DataMutAPI,
259 {
260 unsafe { self.raw_mut().as_mut_ptr().add(self.layout().offset()) }
261 }
262}
263
264pub trait TensorViewAPI<T, B, D>
269where
270 D: DimAPI,
271 B: DeviceAPI<T>,
272{
273 fn view(&self) -> TensorView<'_, T, B, D>;
275}
276
277impl<R, T, B, D> TensorViewAPI<T, B, D> for TensorAny<R, T, B, D>
278where
279 D: DimAPI,
280 R: DataAPI<Data = B::Raw>,
281 B: DeviceAPI<T>,
282{
283 fn view(&self) -> TensorView<'_, T, B, D> {
284 let data = self.data().as_ref();
285 let storage = Storage::new(data, self.device().clone());
286 let layout = self.layout().clone();
287 unsafe { TensorBase::new_unchecked(storage, layout) }
288 }
289}
290
291impl<R, T, B, D> TensorViewAPI<T, B, D> for &TensorAny<R, T, B, D>
292where
293 D: DimAPI,
294 R: DataAPI<Data = B::Raw>,
295 B: DeviceAPI<T>,
296{
297 fn view(&self) -> TensorView<'_, T, B, D> {
298 (*self).view()
299 }
300}
301
302pub trait TensorViewMutAPI<T, B, D>
303where
304 D: DimAPI,
305 B: DeviceAPI<T>,
306{
307 fn view_mut(&mut self) -> TensorMut<'_, T, B, D>;
309}
310
311impl<R, T, B, D> TensorViewMutAPI<T, B, D> for TensorAny<R, T, B, D>
312where
313 D: DimAPI,
314 R: DataMutAPI<Data = B::Raw>,
315 B: DeviceAPI<T>,
316{
317 fn view_mut(&mut self) -> TensorMut<'_, T, B, D> {
318 let device = self.device().clone();
319 let layout = self.layout().clone();
320 let data = self.data_mut().as_mut();
321 let storage = Storage::new(data, device);
322 unsafe { TensorBase::new_unchecked(storage, layout) }
323 }
324}
325
326impl<R, T, B, D> TensorViewMutAPI<T, B, D> for &mut TensorAny<R, T, B, D>
327where
328 D: DimAPI,
329 R: DataMutAPI<Data = B::Raw>,
330 B: DeviceAPI<T>,
331{
332 fn view_mut(&mut self) -> TensorMut<'_, T, B, D> {
333 (*self).view_mut()
334 }
335}
336
337pub trait TensorIntoOwnedAPI<T, B, D>
338where
339 D: DimAPI,
340 B: DeviceAPI<T>,
341{
342 fn into_owned(self) -> Tensor<T, B, D>;
348}
349
350impl<R, T, B, D> TensorIntoOwnedAPI<T, B, D> for TensorAny<R, T, B, D>
351where
352 R: DataCloneAPI<Data = B::Raw>,
353 B::Raw: Clone,
354 T: Clone,
355 D: DimAPI,
356 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
357{
358 fn into_owned(self) -> Tensor<T, B, D> {
359 TensorAny::into_owned(self)
360 }
361}
362
363pub trait TensorRefAPI {}
368impl<R, T, B, D> TensorRefAPI for &TensorAny<R, T, B, D>
369where
370 D: DimAPI,
371 R: DataAPI<Data = B::Raw>,
372 B: DeviceAPI<T>,
373 Self: TensorViewAPI<T, B, D>,
374{
375}
376impl<T, B, D> TensorRefAPI for TensorView<'_, T, B, D>
377where
378 D: DimAPI,
379 B: DeviceAPI<T>,
380 Self: TensorViewAPI<T, B, D>,
381{
382}
383
384pub trait TensorRefMutAPI {}
385impl<R, T, B, D> TensorRefMutAPI for &mut TensorAny<R, T, B, D>
386where
387 D: DimAPI,
388 R: DataMutAPI<Data = B::Raw>,
389 B: DeviceAPI<T>,
390 Self: TensorViewMutAPI<T, B, D>,
391{
392}
393impl<T, B, D> TensorRefMutAPI for TensorMut<'_, T, B, D>
394where
395 D: DimAPI,
396 B: DeviceAPI<T>,
397 Self: TensorViewMutAPI<T, B, D>,
398{
399}
400
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_into_cow() {
        let mut a = arange(3);
        let original_ptr = a.raw().as_ptr();

        // Mutable view converted into a cow.
        let a_cow = a.view_mut().into_cow();
        println!("{a_cow:?}");

        // Immutable view converted into a cow.
        let a_cow = a.view().into_cow();
        println!("{a_cow:?}");

        // Owned tensor converted into a cow: the buffer must be moved, not copied.
        let a_cow = a.into_cow();
        println!("{a_cow:?}");
        assert_eq!(original_ptr, a_cow.raw().as_ptr());
    }

    #[test]
    #[ignore]
    fn test_force_mut() {
        let n = 4096;
        let a = linspace((0.0, 1.0, n * n)).into_shape((n, n));
        for _ in 0..10 {
            let start = std::time::Instant::now();
            (0..n).for_each(|i| {
                let row = a.slice(i);
                // SAFETY: each iteration mutates a distinct row, so the views never overlap.
                let mut row_mut = unsafe { row.force_mut() };
                row_mut *= i as f64 / 2048.0;
            });
            println!("Elapsed time {:?}", start.elapsed());
        }
        println!("{a:16.10}");
    }

    #[test]
    #[ignore]
    #[cfg(feature = "rayon")]
    fn test_force_mut_par() {
        use rayon::prelude::*;
        let n = 4096;
        let a = linspace((0.0, 1.0, n * n)).into_shape((n, n));
        for _ in 0..10 {
            let start = std::time::Instant::now();
            (0..n).into_par_iter().for_each(|i| {
                let row = a.slice(i);
                // SAFETY: each task mutates a distinct row, so concurrent views never overlap.
                let mut row_mut = unsafe { row.force_mut() };
                row_mut *= i as f64 / 2048.0;
            });
            println!("Elapsed time {:?}", start.elapsed());
        }
        println!("{a:16.10}");
    }
}