1use crate::prelude_dev::*;
2use core::mem::transmute;
3
4impl<R, T, B, D> TensorAny<R, T, B, D>
9where
10 R: DataAPI<Data = B::Raw>,
11 D: DimAPI,
12 B: DeviceAPI<T>,
13{
14 pub fn map_fnmut_f<'f, TOut>(&self, mut f: impl FnMut(&T) -> TOut + 'f) -> Result<Tensor<TOut, B, D>>
17 where
18 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
19 B: DeviceOp_MutA_RefB_API<TOut, T, D, dyn FnMut(&mut MaybeUninit<TOut>, &T) + 'f>,
20 {
21 let la = self.layout();
22 let lc = layout_for_array_copy(la, TensorIterOrder::default())?;
23 let device = self.device();
24 let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
25 let mut f_inner = move |c: &mut MaybeUninit<TOut>, a: &T| {
26 c.write(f(a));
27 };
28 device.op_muta_refb_func(storage_c.raw_mut(), &lc, self.raw(), la, &mut f_inner)?;
29 let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
30 return Tensor::new_f(storage_c, lc);
31 }
32
33 pub fn map_fnmut<'f, TOut>(&self, f: impl FnMut(&T) -> TOut + 'f) -> Tensor<TOut, B, D>
36 where
37 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
38 B: DeviceOp_MutA_RefB_API<TOut, T, D, dyn FnMut(&mut MaybeUninit<TOut>, &T) + 'f>,
39 {
40 self.map_fnmut_f(f).rstsr_unwrap()
41 }
42
43 pub fn mapv_fnmut_f<'f, TOut>(&self, mut f: impl FnMut(T) -> TOut + 'f) -> Result<Tensor<TOut, B, D>>
46 where
47 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
48 T: Clone,
49 B: DeviceOp_MutA_RefB_API<TOut, T, D, dyn FnMut(&mut MaybeUninit<TOut>, &T) + 'f>,
50 {
51 self.map_fnmut_f(move |x| f(x.clone()))
52 }
53
54 pub fn mapv_fnmut<'f, TOut>(&self, mut f: impl FnMut(T) -> TOut + 'f) -> Tensor<TOut, B, D>
57 where
58 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
59 T: Clone,
60 B: DeviceOp_MutA_RefB_API<TOut, T, D, dyn FnMut(&mut MaybeUninit<TOut>, &T) + 'f>,
61 {
62 self.map_fnmut_f(move |x| f(x.clone())).rstsr_unwrap()
63 }
64
65 pub fn mapi_fnmut_f<'f>(&mut self, mut f: impl FnMut(&mut T) + 'f) -> Result<()>
68 where
69 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
70 B: DeviceOp_MutA_API<T, D, dyn FnMut(&mut MaybeUninit<T>) + 'f>,
71 {
72 let (la, _) = greedy_layout(self.layout(), false);
73 let device = self.device().clone();
74 let self_raw_mut = unsafe {
75 transmute::<&mut <B as DeviceRawAPI<T>>::Raw, &mut <B as DeviceRawAPI<MaybeUninit<T>>>::Raw>(self.raw_mut())
76 };
77 let mut f_inner = move |x: &mut MaybeUninit<T>| {
78 let x_ref = unsafe { x.assume_init_mut() };
79 f(x_ref);
80 };
81 device.op_muta_func(self_raw_mut, &la, &mut f_inner)
82 }
83
84 pub fn mapi_fnmut<'f>(&mut self, f: impl FnMut(&mut T) + 'f)
87 where
88 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
89 B: DeviceOp_MutA_API<T, D, dyn FnMut(&mut MaybeUninit<T>) + 'f>,
90 {
91 self.mapi_fnmut_f(f).rstsr_unwrap()
92 }
93
94 pub fn mapvi_fnmut_f<'f>(&mut self, mut f: impl FnMut(T) -> T + 'f) -> Result<()>
97 where
98 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
99 T: Clone,
100 B: DeviceOp_MutA_API<T, D, dyn FnMut(&mut MaybeUninit<T>) + 'f>,
101 {
102 self.mapi_fnmut_f(move |x| *x = f(x.clone()))
103 }
104
105 pub fn mapvi_fnmut<'f>(&mut self, f: impl FnMut(T) -> T + 'f)
108 where
109 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
110 T: Clone,
111 B: DeviceOp_MutA_API<T, D, dyn FnMut(&mut MaybeUninit<T>) + 'f>,
112 {
113 self.mapvi_fnmut_f(f).rstsr_unwrap()
114 }
115}
116
117impl<R, T, B, D> TensorAny<R, T, B, D>
120where
121 R: DataAPI<Data = B::Raw>,
122 D: DimAPI,
123 B: DeviceAPI<T>,
124 T: Clone,
125{
126 pub fn mapb_fnmut_f<'f, R2, T2, D2, DOut, TOut>(
127 &self,
128 other: &TensorAny<R2, T2, B, D2>,
129 mut f: impl FnMut(&T, &T2) -> TOut + 'f,
130 ) -> Result<Tensor<TOut, B, DOut>>
131 where
132 R2: DataAPI<Data = <B as DeviceRawAPI<T2>>::Raw>,
133 D2: DimAPI,
134 DOut: DimAPI,
135 D: DimMaxAPI<D2, Max = DOut>,
136 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
137 B: DeviceOp_MutC_RefA_RefB_API<T, T2, TOut, DOut, dyn FnMut(&mut MaybeUninit<TOut>, &T, &T2) + 'f>,
138 {
139 let a = self.view();
141 let b = other.view();
142 rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
144 let la = a.layout();
145 let lb = b.layout();
146 let default_order = a.device().default_order();
147 let (la_b, lb_b) = broadcast_layout(la, lb, default_order)?;
148 let lc_from_a = layout_for_array_copy(&la_b, TensorIterOrder::default())?;
150 let lc_from_b = layout_for_array_copy(&lb_b, TensorIterOrder::default())?;
151 let lc = if lc_from_a == lc_from_b {
152 lc_from_a
153 } else {
154 match self.device().default_order() {
155 RowMajor => la_b.shape().c(),
156 ColMajor => la_b.shape().f(),
157 }
158 };
159 let device = self.device();
160 let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
161 let mut f_inner = move |c: &mut MaybeUninit<TOut>, a: &T, b: &T2| {
162 c.write(f(a, b));
163 };
164 device.op_mutc_refa_refb_func(storage_c.raw_mut(), &lc, self.raw(), &la_b, other.raw(), &lb_b, &mut f_inner)?;
165 let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
166 Tensor::new_f(storage_c, lc)
167 }
168
169 pub fn mapb_fnmut<'f, R2, T2, D2, DOut, TOut>(
170 &self,
171 other: &TensorAny<R2, T2, B, D2>,
172 f: impl FnMut(&T, &T2) -> TOut + 'f,
173 ) -> Tensor<TOut, B, DOut>
174 where
175 R2: DataAPI<Data = <B as DeviceRawAPI<T2>>::Raw>,
176 D2: DimAPI,
177 DOut: DimAPI,
178 D: DimMaxAPI<D2, Max = DOut>,
179 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
180 B: DeviceOp_MutC_RefA_RefB_API<T, T2, TOut, DOut, dyn FnMut(&mut MaybeUninit<TOut>, &T, &T2) + 'f>,
181 {
182 self.mapb_fnmut_f(other, f).rstsr_unwrap()
183 }
184
185 pub fn mapvb_fnmut_f<'f, R2, T2, D2, DOut, TOut>(
186 &self,
187 other: &TensorAny<R2, T2, B, D2>,
188 mut f: impl FnMut(T, T2) -> TOut + 'f,
189 ) -> Result<Tensor<TOut, B, DOut>>
190 where
191 R2: DataAPI<Data = <B as DeviceRawAPI<T2>>::Raw>,
192 D2: DimAPI,
193 DOut: DimAPI,
194 D: DimMaxAPI<D2, Max = DOut>,
195 T: Clone,
196 T2: Clone,
197 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
198 B: DeviceOp_MutC_RefA_RefB_API<T, T2, TOut, DOut, dyn FnMut(&mut MaybeUninit<TOut>, &T, &T2) + 'f>,
199 {
200 self.mapb_fnmut_f(other, move |x, y| f(x.clone(), y.clone()))
201 }
202
203 pub fn mapvb_fnmut<'f, R2, T2, D2, DOut, TOut>(
204 &self,
205 other: &TensorAny<R2, T2, B, D2>,
206 mut f: impl FnMut(T, T2) -> TOut + 'f,
207 ) -> Tensor<TOut, B, DOut>
208 where
209 R2: DataAPI<Data = <B as DeviceRawAPI<T2>>::Raw>,
210 D2: DimAPI,
211 DOut: DimAPI,
212 D: DimMaxAPI<D2, Max = DOut>,
213 T: Clone,
214 T2: Clone,
215 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
216 B: DeviceOp_MutC_RefA_RefB_API<T, T2, TOut, DOut, dyn FnMut(&mut MaybeUninit<TOut>, &T, &T2) + 'f>,
217 {
218 self.mapb_fnmut_f(other, move |x, y| f(x.clone(), y.clone())).rstsr_unwrap()
219 }
220}
221
222impl<R, T, B, D> TensorAny<R, T, B, D>
229where
230 R: DataAPI<Data = B::Raw>,
231 D: DimAPI,
232 B: DeviceAPI<T>,
233{
234 pub fn map_f<'f, TOut>(&self, f: impl Fn(&T) -> TOut + Send + Sync + 'f) -> Result<Tensor<TOut, B, D>>
237 where
238 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
239 B: DeviceOp_MutA_RefB_API<TOut, T, D, dyn Fn(&mut MaybeUninit<TOut>, &T) + Send + Sync + 'f>,
240 {
241 let la = self.layout();
242 let lc = layout_for_array_copy(la, TensorIterOrder::default())?;
243 let device = self.device();
244 let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
245 let mut f_inner = move |c: &mut MaybeUninit<TOut>, a: &T| {
246 c.write(f(a));
247 };
248 device.op_muta_refb_func(storage_c.raw_mut(), &lc, self.raw(), la, &mut f_inner)?;
249 let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
250 return Tensor::new_f(storage_c, lc);
251 }
252
253 pub fn map<'f, TOut>(&self, f: impl Fn(&T) -> TOut + Send + Sync + 'f) -> Tensor<TOut, B, D>
256 where
257 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
258 B: DeviceOp_MutA_RefB_API<TOut, T, D, dyn Fn(&mut MaybeUninit<TOut>, &T) + Send + Sync + 'f>,
259 {
260 self.map_f(f).rstsr_unwrap()
261 }
262
263 pub fn mapv_f<'f, TOut>(&self, f: impl Fn(T) -> TOut + Send + Sync + 'f) -> Result<Tensor<TOut, B, D>>
266 where
267 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
268 T: Clone,
269 B: DeviceOp_MutA_RefB_API<TOut, T, D, dyn Fn(&mut MaybeUninit<TOut>, &T) + Send + Sync + 'f>,
270 {
271 self.map_f(move |x| f(x.clone()))
272 }
273
274 pub fn mapv<'f, TOut>(&self, f: impl Fn(T) -> TOut + Send + Sync + 'f) -> Tensor<TOut, B, D>
277 where
278 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
279 T: Clone,
280 B: DeviceOp_MutA_RefB_API<TOut, T, D, dyn Fn(&mut MaybeUninit<TOut>, &T) + Send + Sync + 'f>,
281 {
282 self.map_f(move |x| f(x.clone())).rstsr_unwrap()
283 }
284
285 pub fn mapi_f<'f>(&mut self, f: impl Fn(&mut T) + Send + Sync + 'f) -> Result<()>
288 where
289 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
290 B: DeviceOp_MutA_API<T, D, dyn Fn(&mut MaybeUninit<T>) + Send + Sync + 'f>,
291 {
292 let (la, _) = greedy_layout(self.layout(), false);
293 let device = self.device().clone();
294 let self_raw_mut = unsafe {
295 transmute::<&mut <B as DeviceRawAPI<T>>::Raw, &mut <B as DeviceRawAPI<MaybeUninit<T>>>::Raw>(self.raw_mut())
296 };
297 let mut f_inner = move |x: &mut MaybeUninit<T>| {
298 let x_ref = unsafe { x.assume_init_mut() };
299 f(x_ref);
300 };
301 device.op_muta_func(self_raw_mut, &la, &mut f_inner)
302 }
303
304 pub fn mapi<'f>(&mut self, f: impl Fn(&mut T) + Send + Sync + 'f)
307 where
308 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
309 B: DeviceOp_MutA_API<T, D, dyn Fn(&mut MaybeUninit<T>) + Send + Sync + 'f>,
310 {
311 self.mapi_f(f).rstsr_unwrap()
312 }
313
314 pub fn mapvi_f<'f>(&mut self, f: impl Fn(T) -> T + Send + Sync + 'f) -> Result<()>
317 where
318 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
319 T: Clone,
320 B: DeviceOp_MutA_API<T, D, dyn Fn(&mut MaybeUninit<T>) + Send + Sync + 'f>,
321 {
322 self.mapi_f(move |x| *x = f(x.clone()))
323 }
324
325 pub fn mapvi<'f>(&mut self, f: impl Fn(T) -> T + Send + Sync + 'f)
328 where
329 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
330 T: Clone,
331 B: DeviceOp_MutA_API<T, D, dyn Fn(&mut MaybeUninit<T>) + Send + Sync + 'f>,
332 {
333 self.mapvi_f(f).rstsr_unwrap()
334 }
335}
336
337impl<R, T, B, D> TensorAny<R, T, B, D>
340where
341 R: DataAPI<Data = B::Raw>,
342 D: DimAPI,
343 B: DeviceAPI<T>,
344 T: Clone,
345{
346 pub fn mapb_f<'f, R2, T2, D2, DOut, TOut>(
347 &self,
348 other: &TensorAny<R2, T2, B, D2>,
349 f: impl Fn(&T, &T2) -> TOut + Send + Sync + 'f,
350 ) -> Result<Tensor<TOut, B, DOut>>
351 where
352 R2: DataAPI<Data = <B as DeviceRawAPI<T2>>::Raw>,
353 D2: DimAPI,
354 DOut: DimAPI,
355 D: DimMaxAPI<D2, Max = DOut>,
356 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
357 B: DeviceOp_MutC_RefA_RefB_API<T, T2, TOut, DOut, dyn Fn(&mut MaybeUninit<TOut>, &T, &T2) + Send + Sync + 'f>,
358 {
359 let a = self.view();
361 let b = other.view();
362 rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
364 let la = a.layout();
365 let lb = b.layout();
366 let default_order = a.device().default_order();
367 let (la_b, lb_b) = broadcast_layout(la, lb, default_order)?;
368 let lc_from_a = layout_for_array_copy(&la_b, TensorIterOrder::default())?;
370 let lc_from_b = layout_for_array_copy(&lb_b, TensorIterOrder::default())?;
371 let lc = if lc_from_a == lc_from_b {
372 lc_from_a
373 } else {
374 match self.device().default_order() {
375 RowMajor => la_b.shape().c(),
376 ColMajor => la_b.shape().f(),
377 }
378 };
379 let device = self.device();
380 let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
381 let mut f_inner = move |c: &mut MaybeUninit<TOut>, a: &T, b: &T2| {
382 c.write(f(a, b));
383 };
384 device.op_mutc_refa_refb_func(storage_c.raw_mut(), &lc, self.raw(), &la_b, other.raw(), &lb_b, &mut f_inner)?;
385 let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
386 Tensor::new_f(storage_c, lc)
387 }
388
389 pub fn mapb<'f, R2, T2, D2, DOut, TOut>(
390 &self,
391 other: &TensorAny<R2, T2, B, D2>,
392 f: impl Fn(&T, &T2) -> TOut + Send + Sync + 'f,
393 ) -> Tensor<TOut, B, DOut>
394 where
395 R2: DataAPI<Data = <B as DeviceRawAPI<T2>>::Raw>,
396 D2: DimAPI,
397 DOut: DimAPI,
398 D: DimMaxAPI<D2, Max = DOut>,
399 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
400 B: DeviceOp_MutC_RefA_RefB_API<T, T2, TOut, DOut, dyn Fn(&mut MaybeUninit<TOut>, &T, &T2) + Send + Sync + 'f>,
401 {
402 self.mapb_f(other, f).rstsr_unwrap()
403 }
404
405 pub fn mapvb_f<'f, R2, T2, D2, DOut, TOut>(
406 &self,
407 other: &TensorAny<R2, T2, B, D2>,
408 f: impl Fn(T, T2) -> TOut + Send + Sync + 'f,
409 ) -> Result<Tensor<TOut, B, DOut>>
410 where
411 R2: DataAPI<Data = <B as DeviceRawAPI<T2>>::Raw>,
412 D2: DimAPI,
413 DOut: DimAPI,
414 D: DimMaxAPI<D2, Max = DOut>,
415 T: Clone,
416 T2: Clone,
417 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
418 B: DeviceOp_MutC_RefA_RefB_API<T, T2, TOut, DOut, dyn Fn(&mut MaybeUninit<TOut>, &T, &T2) + Send + Sync + 'f>,
419 {
420 self.mapb_f(other, move |x, y| f(x.clone(), y.clone()))
421 }
422
423 pub fn mapvb<'f, R2, T2, D2, DOut, TOut>(
424 &self,
425 other: &TensorAny<R2, T2, B, D2>,
426 f: impl Fn(T, T2) -> TOut + Send + Sync + 'f,
427 ) -> Tensor<TOut, B, DOut>
428 where
429 R2: DataAPI<Data = <B as DeviceRawAPI<T2>>::Raw>,
430 D2: DimAPI,
431 DOut: DimAPI,
432 D: DimMaxAPI<D2, Max = DOut>,
433 T: Clone,
434 T2: Clone,
435 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
436 B: DeviceOp_MutC_RefA_RefB_API<T, T2, TOut, DOut, dyn Fn(&mut MaybeUninit<TOut>, &T, &T2) + Send + Sync + 'f>,
437 {
438 self.mapb_f(other, move |x, y| f(x.clone(), y.clone())).rstsr_unwrap()
439 }
440}
441
#[cfg(test)]
mod tests_fnmut {
    use super::*;

    /// A stateful (`FnMut`) unary map doubles every element and is invoked once per element.
    #[test]
    fn test_mapv() {
        let device = DeviceCpuSerial::default();
        let mut calls = 0;
        let double_and_count = |x| {
            calls += 1;
            x * 2.0
        };
        let input = asarray((vec![1., 2., 3., 4.], &device));
        let b = input.mapv_fnmut(double_and_count);
        assert!(allclose_f64(&b, &vec![2., 4., 6., 8.].into()));
        assert_eq!(calls, 4);
        println!("{b:?}");
    }

    /// A stateful binary map broadcasts the 1-D operand and is invoked once per output element.
    #[test]
    fn test_mapv_binary() {
        let device = DeviceCpuSerial::default();
        let mut calls = 0;
        let weighted_sum = |x, y| {
            calls += 1;
            2.0 * x + 3.0 * y
        };
        // The length-3 operand must line up with the size-3 axis, which is the last axis in
        // row-major builds and the first axis in column-major builds.
        #[cfg(not(feature = "col_major"))]
        let shape = [2, 3];
        #[cfg(feature = "col_major")]
        let shape = [3, 2];
        let lhs = linspace((1., 6., 6, &device)).into_shape_assume_contig(shape);
        let rhs = linspace((1., 3., 3, &device));
        let c = lhs.mapvb_fnmut(&rhs, weighted_sum);
        assert_eq!(calls, 6);
        println!("{c:?}");
        assert!(allclose_f64(&c.raw().into(), &vec![5., 10., 15., 11., 16., 21.].into()));
    }
}
497
#[cfg(test)]
mod tests_sync {
    use super::*;

    /// A `Fn + Send + Sync` unary map doubles every element.
    #[test]
    fn test_mapv() {
        let double = |x| x * 2.0;
        let input = asarray(vec![1., 2., 3., 4.]);
        let b = input.mapv(double);
        assert!(allclose_f64(&b, &vec![2., 4., 6., 8.].into()));
        println!("{b:?}");
    }

    /// A `Fn + Send + Sync` binary map broadcasts the 1-D operand against the 2-D one.
    #[test]
    fn test_mapv_binary() {
        let weighted_sum = |x, y| 2.0 * x + 3.0 * y;
        // The length-3 operand must line up with the size-3 axis, which is the last axis in
        // row-major builds and the first axis in column-major builds.
        #[cfg(not(feature = "col_major"))]
        let shape = [2, 3];
        #[cfg(feature = "col_major")]
        let shape = [3, 2];
        let lhs = linspace((1., 6., 6)).into_shape_assume_contig(shape);
        let rhs = linspace((1., 3., 3));
        let c = lhs.mapvb(&rhs, weighted_sum);
        assert!(allclose_f64(&c.raw().into(), &vec![5., 10., 15., 11., 16., 21.].into()));
    }
}