1use crate::prelude_dev::*;
2
3pub fn op_mutc_refa_refb_func<RA, RB, RC, DA, DB, DC, TA, TB, TC, B, F>(
6 c: &mut TensorAny<RC, TC, B, DC>,
7 a: &TensorAny<RA, TA, B, DA>,
8 b: &TensorAny<RB, TB, B, DB>,
9 f: &mut F,
10) -> Result<()>
11where
12 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
14 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
15 RC: DataMutAPI<Data = <B as DeviceRawAPI<TC>>::Raw>,
16 DA: DimAPI,
17 DB: DimAPI,
18 DC: DimAPI,
19 B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
20 DC: DimMaxAPI<DA, Max = DC> + DimMaxAPI<DB, Max = DC>,
22 B: DeviceOp_MutC_RefA_RefB_API<TA, TB, TC, DC, F>,
24 F: FnMut(&mut TC, &TA, &TB),
25{
26 rstsr_assert!(c.device().same_device(a.device()), DeviceMismatch)?;
27 rstsr_assert!(c.device().same_device(b.device()), DeviceMismatch)?;
28 let lc = c.layout();
29 let la = a.layout();
30 let lb = b.layout();
31 let default_order = c.device().default_order();
32 let (lc_b, la_b) = broadcast_layout_to_first(lc, la, default_order)?;
35 rstsr_assert_eq!(lc_b, *lc, InvalidLayout)?;
36 let (lc_b, lb_b) = broadcast_layout_to_first(lc, lb, default_order)?;
37 rstsr_assert_eq!(lc_b, *lc, InvalidLayout)?;
38 let device = c.device().clone();
40 device.op_mutc_refa_refb_func(c.raw_mut(), &lc_b, a.raw(), &la_b, b.raw(), &lb_b, f)
41}
42
43pub fn op_refa_refb_func<RA, RB, DA, DB, DC, TA, TB, TC, B, F>(
44 a: &TensorAny<RA, TA, B, DA>,
45 b: &TensorAny<RB, TB, B, DB>,
46 f: &mut F,
47) -> Result<Tensor<TC, B, DC>>
48where
49 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
51 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
52 DA: DimAPI,
53 DB: DimAPI,
54 DC: DimAPI,
55 B: DeviceAPI<TA> + DeviceAPI<TB>,
56 DA: DimMaxAPI<DB, Max = DC>,
58 B: DeviceOp_MutC_RefA_RefB_API<TA, TB, TC, DC, F>,
60 B: DeviceCreationAnyAPI<TC>,
61 F: FnMut(&mut TC, &TA, &TB),
62{
63 rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
64 let la = a.layout();
65 let lb = b.layout();
66 let default_order = a.device().default_order();
67 let (la_b, lb_b) = broadcast_layout(la, lb, default_order)?;
68 let lc_from_a = layout_for_array_copy(&la_b, TensorIterOrder::K)?;
70 let lc_from_b = layout_for_array_copy(&lb_b, TensorIterOrder::K)?;
71 let lc = if lc_from_a == lc_from_b {
72 lc_from_a
73 } else {
74 match default_order {
75 RowMajor => la_b.shape().c(),
76 ColMajor => la_b.shape().f(),
77 }
78 };
79 let device = a.device();
81 let mut storage_c = unsafe { device.empty_impl(lc.bounds_index()?.1)? };
82 device.op_mutc_refa_refb_func(storage_c.raw_mut(), &lc, a.raw(), &la_b, b.raw(), &lb_b, f)?;
84 Tensor::new_f(storage_c, lc)
86}
87
88pub fn op_muta_refb_func<RA, RB, DA, DB, TA, TB, B, F>(
89 a: &mut TensorAny<RA, TA, B, DA>,
90 b: &TensorAny<RB, TB, B, DB>,
91 f: &mut F,
92) -> Result<()>
93where
94 RA: DataMutAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
96 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
97 DA: DimAPI,
98 DB: DimAPI,
99 B: DeviceAPI<TA> + DeviceAPI<TB>,
100 DA: DimMaxAPI<DB, Max = DA>,
102 B: DeviceOp_MutA_RefB_API<TA, TB, DA, F>,
104 F: FnMut(&mut TA, &TB),
105{
106 rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
107 let la = a.layout();
108 let lb = b.layout();
109 let default_order = a.device().default_order();
110 let (la_b, lb_b) = broadcast_layout_to_first(la, lb, default_order)?;
113 rstsr_assert_eq!(la_b, *la, InvalidLayout)?;
114 let device = a.device().clone();
116 device.op_muta_refb_func(a.raw_mut(), &la_b, b.raw(), &lb_b, f)
117}
118
119pub fn op_muta_func<R, T, D, B, F>(a: &mut TensorAny<R, T, B, D>, f: &mut F) -> Result<()>
120where
121 R: DataMutAPI<Data = B::Raw>,
122 D: DimAPI,
123 B: DeviceAPI<T>,
124 B: DeviceOp_MutA_API<T, D, F>,
125 F: FnMut(&mut T),
126{
127 let la = a.layout().clone();
128 let device = a.device().clone();
129 device.op_muta_func(a.raw_mut(), &la, f)
130}
131
132