1use crate::prelude_dev::*;
2use core::mem::transmute;
3
4impl<R, T, B, D> TensorAny<R, T, B, D>
9where
10 R: DataAPI<Data = B::Raw>,
11 D: DimAPI,
12 B: DeviceAPI<T>,
13{
14 pub fn map_fnmut_f<'f, TOut>(&self, mut f: impl FnMut(&T) -> TOut + 'f) -> Result<Tensor<TOut, B, D>>
17 where
18 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
19 B: Op_MutA_RefB_API<TOut, T, D, dyn FnMut(&mut MaybeUninit<TOut>, &T) + 'f>,
20 {
21 let la = self.layout();
22 let lc = layout_for_array_copy(la, TensorIterOrder::default())?;
23 let device = self.device();
24 let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
25 let mut f_inner = move |c: &mut MaybeUninit<TOut>, a: &T| {
26 c.write(f(a));
27 };
28 device.op_muta_refb_func(storage_c.raw_mut(), &lc, self.raw(), la, &mut f_inner)?;
29 let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
30 return Tensor::new_f(storage_c, lc);
31 }
32
33 pub fn map_fnmut<'f, TOut>(&self, f: impl FnMut(&T) -> TOut + 'f) -> Tensor<TOut, B, D>
36 where
37 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
38 B: Op_MutA_RefB_API<TOut, T, D, dyn FnMut(&mut MaybeUninit<TOut>, &T) + 'f>,
39 {
40 self.map_fnmut_f(f).rstsr_unwrap()
41 }
42
43 pub fn mapv_fnmut_f<'f, TOut>(&self, mut f: impl FnMut(T) -> TOut + 'f) -> Result<Tensor<TOut, B, D>>
46 where
47 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
48 T: Clone,
49 B: Op_MutA_RefB_API<TOut, T, D, dyn FnMut(&mut MaybeUninit<TOut>, &T) + 'f>,
50 {
51 self.map_fnmut_f(move |x| f(x.clone()))
52 }
53
54 pub fn mapv_fnmut<'f, TOut>(&self, mut f: impl FnMut(T) -> TOut + 'f) -> Tensor<TOut, B, D>
57 where
58 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
59 T: Clone,
60 B: Op_MutA_RefB_API<TOut, T, D, dyn FnMut(&mut MaybeUninit<TOut>, &T) + 'f>,
61 {
62 self.map_fnmut_f(move |x| f(x.clone())).rstsr_unwrap()
63 }
64
65 pub fn mapi_fnmut_f<'f>(&mut self, mut f: impl FnMut(&mut T) + 'f) -> Result<()>
68 where
69 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
70 B: Op_MutA_API<T, D, dyn FnMut(&mut MaybeUninit<T>) + 'f>,
71 {
72 let (la, _) = greedy_layout(self.layout(), false);
73 let device = self.device().clone();
74 let self_raw_mut = unsafe {
75 transmute::<&mut <B as DeviceRawAPI<T>>::Raw, &mut <B as DeviceRawAPI<MaybeUninit<T>>>::Raw>(self.raw_mut())
76 };
77 let mut f_inner = move |x: &mut MaybeUninit<T>| {
78 let x_ref = unsafe { x.assume_init_mut() };
79 f(x_ref);
80 };
81 device.op_muta_func(self_raw_mut, &la, &mut f_inner)
82 }
83
84 pub fn mapi_fnmut<'f>(&mut self, f: impl FnMut(&mut T) + 'f)
87 where
88 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
89 B: Op_MutA_API<T, D, dyn FnMut(&mut MaybeUninit<T>) + 'f>,
90 {
91 self.mapi_fnmut_f(f).rstsr_unwrap()
92 }
93
94 pub fn mapvi_fnmut_f<'f>(&mut self, mut f: impl FnMut(T) -> T + 'f) -> Result<()>
97 where
98 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
99 T: Clone,
100 B: Op_MutA_API<T, D, dyn FnMut(&mut MaybeUninit<T>) + 'f>,
101 {
102 self.mapi_fnmut_f(move |x| *x = f(x.clone()))
103 }
104
105 pub fn mapvi_fnmut<'f>(&mut self, f: impl FnMut(T) -> T + 'f)
108 where
109 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
110 T: Clone,
111 B: Op_MutA_API<T, D, dyn FnMut(&mut MaybeUninit<T>) + 'f>,
112 {
113 self.mapvi_fnmut_f(f).rstsr_unwrap()
114 }
115}
116
117impl<R, T, B, D> TensorAny<R, T, B, D>
120where
121 R: DataAPI<Data = B::Raw>,
122 D: DimAPI,
123 B: DeviceAPI<T>,
124 T: Clone,
125{
126 pub fn mapb_fnmut_f<'f, R2, T2, D2, DOut, TOut>(
127 &self,
128 other: &TensorAny<R2, T2, B, D2>,
129 mut f: impl FnMut(&T, &T2) -> TOut + 'f,
130 ) -> Result<Tensor<TOut, B, DOut>>
131 where
132 R2: DataAPI<Data = <B as DeviceRawAPI<T2>>::Raw>,
133 D2: DimAPI,
134 DOut: DimAPI,
135 D: DimMaxAPI<D2, Max = DOut>,
136 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
137 B: Op_MutC_RefA_RefB_API<T, T2, TOut, DOut, dyn FnMut(&mut MaybeUninit<TOut>, &T, &T2) + 'f>,
138 {
139 let a = self.view();
141 let b = other.view();
142 rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
144 let la = a.layout();
145 let lb = b.layout();
146 let default_order = a.device().default_order();
147 let (la_b, lb_b) = broadcast_layout(la, lb, default_order)?;
148 let lc = match TensorIterOrder::default() {
150 TensorIterOrder::C => la_b.shape().c(),
151 TensorIterOrder::F => la_b.shape().f(),
152 _ => get_layout_for_binary_op(&la_b, &lb_b, default_order)?,
153 };
154 let device = self.device();
156 let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
157 let mut f_inner = move |c: &mut MaybeUninit<TOut>, a: &T, b: &T2| {
158 c.write(f(a, b));
159 };
160 device.op_mutc_refa_refb_func(storage_c.raw_mut(), &lc, self.raw(), &la_b, other.raw(), &lb_b, &mut f_inner)?;
161 let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
162 Tensor::new_f(storage_c, lc)
163 }
164
165 pub fn mapb_fnmut<'f, R2, T2, D2, DOut, TOut>(
166 &self,
167 other: &TensorAny<R2, T2, B, D2>,
168 f: impl FnMut(&T, &T2) -> TOut + 'f,
169 ) -> Tensor<TOut, B, DOut>
170 where
171 R2: DataAPI<Data = <B as DeviceRawAPI<T2>>::Raw>,
172 D2: DimAPI,
173 DOut: DimAPI,
174 D: DimMaxAPI<D2, Max = DOut>,
175 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
176 B: Op_MutC_RefA_RefB_API<T, T2, TOut, DOut, dyn FnMut(&mut MaybeUninit<TOut>, &T, &T2) + 'f>,
177 {
178 self.mapb_fnmut_f(other, f).rstsr_unwrap()
179 }
180
181 pub fn mapvb_fnmut_f<'f, R2, T2, D2, DOut, TOut>(
182 &self,
183 other: &TensorAny<R2, T2, B, D2>,
184 mut f: impl FnMut(T, T2) -> TOut + 'f,
185 ) -> Result<Tensor<TOut, B, DOut>>
186 where
187 R2: DataAPI<Data = <B as DeviceRawAPI<T2>>::Raw>,
188 D2: DimAPI,
189 DOut: DimAPI,
190 D: DimMaxAPI<D2, Max = DOut>,
191 T: Clone,
192 T2: Clone,
193 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
194 B: Op_MutC_RefA_RefB_API<T, T2, TOut, DOut, dyn FnMut(&mut MaybeUninit<TOut>, &T, &T2) + 'f>,
195 {
196 self.mapb_fnmut_f(other, move |x, y| f(x.clone(), y.clone()))
197 }
198
199 pub fn mapvb_fnmut<'f, R2, T2, D2, DOut, TOut>(
200 &self,
201 other: &TensorAny<R2, T2, B, D2>,
202 mut f: impl FnMut(T, T2) -> TOut + 'f,
203 ) -> Tensor<TOut, B, DOut>
204 where
205 R2: DataAPI<Data = <B as DeviceRawAPI<T2>>::Raw>,
206 D2: DimAPI,
207 DOut: DimAPI,
208 D: DimMaxAPI<D2, Max = DOut>,
209 T: Clone,
210 T2: Clone,
211 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
212 B: Op_MutC_RefA_RefB_API<T, T2, TOut, DOut, dyn FnMut(&mut MaybeUninit<TOut>, &T, &T2) + 'f>,
213 {
214 self.mapb_fnmut_f(other, move |x, y| f(x.clone(), y.clone())).rstsr_unwrap()
215 }
216}
217
218impl<R, T, B, D> TensorAny<R, T, B, D>
225where
226 R: DataAPI<Data = B::Raw>,
227 D: DimAPI,
228 B: DeviceAPI<T>,
229{
230 pub fn map_f<'f, TOut>(&self, f: impl Fn(&T) -> TOut + Send + Sync + 'f) -> Result<Tensor<TOut, B, D>>
233 where
234 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
235 B: Op_MutA_RefB_API<TOut, T, D, dyn Fn(&mut MaybeUninit<TOut>, &T) + Send + Sync + 'f>,
236 {
237 let la = self.layout();
238 let lc = layout_for_array_copy(la, TensorIterOrder::default())?;
239 let device = self.device();
240 let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
241 let mut f_inner = move |c: &mut MaybeUninit<TOut>, a: &T| {
242 c.write(f(a));
243 };
244 device.op_muta_refb_func(storage_c.raw_mut(), &lc, self.raw(), la, &mut f_inner)?;
245 let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
246 return Tensor::new_f(storage_c, lc);
247 }
248
249 pub fn map<'f, TOut>(&self, f: impl Fn(&T) -> TOut + Send + Sync + 'f) -> Tensor<TOut, B, D>
252 where
253 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
254 B: Op_MutA_RefB_API<TOut, T, D, dyn Fn(&mut MaybeUninit<TOut>, &T) + Send + Sync + 'f>,
255 {
256 self.map_f(f).rstsr_unwrap()
257 }
258
259 pub fn mapv_f<'f, TOut>(&self, f: impl Fn(T) -> TOut + Send + Sync + 'f) -> Result<Tensor<TOut, B, D>>
262 where
263 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
264 T: Clone,
265 B: Op_MutA_RefB_API<TOut, T, D, dyn Fn(&mut MaybeUninit<TOut>, &T) + Send + Sync + 'f>,
266 {
267 self.map_f(move |x| f(x.clone()))
268 }
269
270 pub fn mapv<'f, TOut>(&self, f: impl Fn(T) -> TOut + Send + Sync + 'f) -> Tensor<TOut, B, D>
273 where
274 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
275 T: Clone,
276 B: Op_MutA_RefB_API<TOut, T, D, dyn Fn(&mut MaybeUninit<TOut>, &T) + Send + Sync + 'f>,
277 {
278 self.map_f(move |x| f(x.clone())).rstsr_unwrap()
279 }
280
281 pub fn mapi_f<'f>(&mut self, f: impl Fn(&mut T) + Send + Sync + 'f) -> Result<()>
284 where
285 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
286 B: Op_MutA_API<T, D, dyn Fn(&mut MaybeUninit<T>) + Send + Sync + 'f>,
287 {
288 let (la, _) = greedy_layout(self.layout(), false);
289 let device = self.device().clone();
290 let self_raw_mut = unsafe {
291 transmute::<&mut <B as DeviceRawAPI<T>>::Raw, &mut <B as DeviceRawAPI<MaybeUninit<T>>>::Raw>(self.raw_mut())
292 };
293 let mut f_inner = move |x: &mut MaybeUninit<T>| {
294 let x_ref = unsafe { x.assume_init_mut() };
295 f(x_ref);
296 };
297 device.op_muta_func(self_raw_mut, &la, &mut f_inner)
298 }
299
300 pub fn mapi<'f>(&mut self, f: impl Fn(&mut T) + Send + Sync + 'f)
303 where
304 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
305 B: Op_MutA_API<T, D, dyn Fn(&mut MaybeUninit<T>) + Send + Sync + 'f>,
306 {
307 self.mapi_f(f).rstsr_unwrap()
308 }
309
310 pub fn mapvi_f<'f>(&mut self, f: impl Fn(T) -> T + Send + Sync + 'f) -> Result<()>
313 where
314 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
315 T: Clone,
316 B: Op_MutA_API<T, D, dyn Fn(&mut MaybeUninit<T>) + Send + Sync + 'f>,
317 {
318 self.mapi_f(move |x| *x = f(x.clone()))
319 }
320
321 pub fn mapvi<'f>(&mut self, f: impl Fn(T) -> T + Send + Sync + 'f)
324 where
325 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
326 T: Clone,
327 B: Op_MutA_API<T, D, dyn Fn(&mut MaybeUninit<T>) + Send + Sync + 'f>,
328 {
329 self.mapvi_f(f).rstsr_unwrap()
330 }
331}
332
333impl<R, T, B, D> TensorAny<R, T, B, D>
336where
337 R: DataAPI<Data = B::Raw>,
338 D: DimAPI,
339 B: DeviceAPI<T>,
340 T: Clone,
341{
342 pub fn mapb_f<'f, R2, T2, D2, DOut, TOut>(
343 &self,
344 other: &TensorAny<R2, T2, B, D2>,
345 f: impl Fn(&T, &T2) -> TOut + Send + Sync + 'f,
346 ) -> Result<Tensor<TOut, B, DOut>>
347 where
348 R2: DataAPI<Data = <B as DeviceRawAPI<T2>>::Raw>,
349 D2: DimAPI,
350 DOut: DimAPI,
351 D: DimMaxAPI<D2, Max = DOut>,
352 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
353 B: Op_MutC_RefA_RefB_API<T, T2, TOut, DOut, dyn Fn(&mut MaybeUninit<TOut>, &T, &T2) + Send + Sync + 'f>,
354 {
355 let a = self.view();
357 let b = other.view();
358 rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
360 let la = a.layout();
361 let lb = b.layout();
362 let default_order = a.device().default_order();
363 let (la_b, lb_b) = broadcast_layout(la, lb, default_order)?;
364 let lc = match TensorIterOrder::default() {
366 TensorIterOrder::C => la_b.shape().c(),
367 TensorIterOrder::F => la_b.shape().f(),
368 _ => get_layout_for_binary_op(&la_b, &lb_b, default_order)?,
369 };
370 let device = self.device();
372 let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
373 let mut f_inner = move |c: &mut MaybeUninit<TOut>, a: &T, b: &T2| {
374 c.write(f(a, b));
375 };
376 device.op_mutc_refa_refb_func(storage_c.raw_mut(), &lc, self.raw(), &la_b, other.raw(), &lb_b, &mut f_inner)?;
377 let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
378 Tensor::new_f(storage_c, lc)
379 }
380
381 pub fn mapb<'f, R2, T2, D2, DOut, TOut>(
382 &self,
383 other: &TensorAny<R2, T2, B, D2>,
384 f: impl Fn(&T, &T2) -> TOut + Send + Sync + 'f,
385 ) -> Tensor<TOut, B, DOut>
386 where
387 R2: DataAPI<Data = <B as DeviceRawAPI<T2>>::Raw>,
388 D2: DimAPI,
389 DOut: DimAPI,
390 D: DimMaxAPI<D2, Max = DOut>,
391 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
392 B: Op_MutC_RefA_RefB_API<T, T2, TOut, DOut, dyn Fn(&mut MaybeUninit<TOut>, &T, &T2) + Send + Sync + 'f>,
393 {
394 self.mapb_f(other, f).rstsr_unwrap()
395 }
396
397 pub fn mapvb_f<'f, R2, T2, D2, DOut, TOut>(
398 &self,
399 other: &TensorAny<R2, T2, B, D2>,
400 f: impl Fn(T, T2) -> TOut + Send + Sync + 'f,
401 ) -> Result<Tensor<TOut, B, DOut>>
402 where
403 R2: DataAPI<Data = <B as DeviceRawAPI<T2>>::Raw>,
404 D2: DimAPI,
405 DOut: DimAPI,
406 D: DimMaxAPI<D2, Max = DOut>,
407 T: Clone,
408 T2: Clone,
409 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
410 B: Op_MutC_RefA_RefB_API<T, T2, TOut, DOut, dyn Fn(&mut MaybeUninit<TOut>, &T, &T2) + Send + Sync + 'f>,
411 {
412 self.mapb_f(other, move |x, y| f(x.clone(), y.clone()))
413 }
414
415 pub fn mapvb<'f, R2, T2, D2, DOut, TOut>(
416 &self,
417 other: &TensorAny<R2, T2, B, D2>,
418 f: impl Fn(T, T2) -> TOut + Send + Sync + 'f,
419 ) -> Tensor<TOut, B, DOut>
420 where
421 R2: DataAPI<Data = <B as DeviceRawAPI<T2>>::Raw>,
422 D2: DimAPI,
423 DOut: DimAPI,
424 D: DimMaxAPI<D2, Max = DOut>,
425 T: Clone,
426 T2: Clone,
427 B: DeviceAPI<TOut> + DeviceCreationAnyAPI<TOut>,
428 B: Op_MutC_RefA_RefB_API<T, T2, TOut, DOut, dyn Fn(&mut MaybeUninit<TOut>, &T, &T2) + Send + Sync + 'f>,
429 {
430 self.mapb_f(other, move |x, y| f(x.clone(), y.clone())).rstsr_unwrap()
431 }
432}
433
#[cfg(test)]
mod tests_fnmut {
    use super::*;

    /// A stateful (`FnMut`) unary map doubles every element and is called once per element.
    #[test]
    fn test_mapv() {
        let device = DeviceCpuSerial::default();
        let mut ncalls = 0;
        let a = asarray((vec![1., 2., 3., 4.], &device));
        let b = a.mapv_fnmut(|x| {
            ncalls += 1;
            x * 2.0
        });
        assert!(allclose_f64(&b, &vec![2., 4., 6., 8.].into()));
        assert_eq!(ncalls, 4);
        println!("{b:?}");
    }

    /// A stateful binary map broadcasts a length-3 vector against a 6-element matrix.
    ///
    /// The matrix shape is chosen per memory order so that the raw output buffer is the same
    /// in both row-major and column-major builds.
    #[test]
    fn test_mapv_binary() {
        let device = DeviceCpuSerial::default();
        let mut ncalls = 0;
        #[cfg(not(feature = "col_major"))]
        let a = linspace((1., 6., 6, &device)).into_shape([2, 3]);
        #[cfg(feature = "col_major")]
        let a = linspace((1., 6., 6, &device)).into_shape([3, 2]);
        let b = linspace((1., 3., 3, &device));
        let c = a.mapvb_fnmut(&b, |x, y| {
            ncalls += 1;
            2.0 * x + 3.0 * y
        });
        assert_eq!(ncalls, 6);
        println!("{c:?}");
        assert!(allclose_f64(&c.raw().into(), &vec![5., 10., 15., 11., 16., 21.].into()));
    }
}
489
#[cfg(test)]
mod tests_sync {
    use super::*;

    /// A `Fn + Send + Sync` unary map doubles every element.
    #[test]
    fn test_mapv() {
        let a = asarray(vec![1., 2., 3., 4.]);
        let b = a.mapv(|x| x * 2.0);
        assert!(allclose_f64(&b, &vec![2., 4., 6., 8.].into()));
        println!("{b:?}");
    }

    /// A `Fn + Send + Sync` binary map broadcasts a length-3 vector against a 6-element matrix.
    ///
    /// The matrix shape is chosen per memory order so that the raw output buffer is the same
    /// in both row-major and column-major builds.
    #[test]
    fn test_mapv_binary() {
        #[cfg(not(feature = "col_major"))]
        let a = linspace((1., 6., 6)).into_shape([2, 3]);
        #[cfg(feature = "col_major")]
        let a = linspace((1., 6., 6)).into_shape([3, 2]);
        let b = linspace((1., 3., 3));
        let c = a.mapvb(&b, |x, y| 2.0 * x + 3.0 * y);
        assert!(allclose_f64(&c.raw().into(), &vec![5., 10., 15., 11., 16., 21.].into()));
    }
}