rstsr_core/tensor/operators/op_with_func.rs

use crate::prelude_dev::*;

/* #region op_func */
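
/// Applies `f(&mut c, &a, &b)` element-wise, writing the result into the
/// existing tensor `c`.
///
/// `a` and `b` must be broadcastable to the layout of `c` without changing it;
/// the kernel itself is supplied by the device through
/// `DeviceOp_MutC_RefA_RefB_API`.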
pub fn op_mutc_refa_refb_func<RA, RB, RC, DA, DB, DC, TA, TB, TC, B, F>(
    c: &mut TensorAny<RC, TC, B, DC>,
    a: &TensorAny<RA, TA, B, DA>,
    b: &TensorAny<RB, TB, B, DB>,
    f: &mut F,
) -> Result<()>
where
    // lifetime and data constraints
    RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
    RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
    RC: DataMutAPI<Data = <B as DeviceRawAPI<TC>>::Raw>,
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
    // broadcast constraints
    DC: DimMaxAPI<DA, Max = DC> + DimMaxAPI<DB, Max = DC>,
    // operation constraints
    B: DeviceOp_MutC_RefA_RefB_API<TA, TB, TC, DC, F>,
    F: FnMut(&mut TC, &TA, &TB),
{
    rstsr_assert!(c.device().same_device(a.device()), DeviceMismatch)?;
    rstsr_assert!(c.device().same_device(b.device()), DeviceMismatch)?;
    let lc = c.layout();
    let la = a.layout();
    let lb = b.layout();
    let default_order = c.device().default_order();
    // both input layouts must be broadcastable to lc:
    // broadcast each against lc, then check that the result is still lc
    let (lc_b, la_b) = broadcast_layout_to_first(lc, la, default_order)?;
    rstsr_assert_eq!(lc_b, *lc, InvalidLayout)?;
    let (lc_b, lb_b) = broadcast_layout_to_first(lc, lb, default_order)?;
    rstsr_assert_eq!(lc_b, *lc, InvalidLayout)?;
    // op provided by device
    let device = c.device().clone();
    device.op_mutc_refa_refb_func(c.raw_mut(), &lc_b, a.raw(), &la_b, b.raw(), &lb_b, f)
}
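
/// Applies `f(&mut c, &a, &b)` element-wise over `a` and `b`, allocating and
/// returning the output tensor `c`.
///
/// When the broadcasted layouts of `a` and `b` agree on an iteration order,
/// the output keeps that order (`TensorIterOrder::K`); otherwise it falls back
/// to the device's default order (C- or F-contiguous).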
pub fn op_refa_refb_func<RA, RB, DA, DB, DC, TA, TB, TC, B, F>(
    a: &TensorAny<RA, TA, B, DA>,
    b: &TensorAny<RB, TB, B, DB>,
    f: &mut F,
) -> Result<Tensor<TC, B, DC>>
where
    // lifetime and data constraints
    RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
    RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    B: DeviceAPI<TA> + DeviceAPI<TB>,
    // broadcast constraints
    DA: DimMaxAPI<DB, Max = DC>,
    // operation constraints
    B: DeviceOp_MutC_RefA_RefB_API<TA, TB, TC, DC, F>,
    B: DeviceCreationAnyAPI<TC>,
    F: FnMut(&mut TC, &TA, &TB),
{
    rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
    let la = a.layout();
    let lb = b.layout();
    let default_order = a.device().default_order();
    let (la_b, lb_b) = broadcast_layout(la, lb, default_order)?;
    // generate output layout; prefer the common iteration order of the
    // broadcasted inputs, falling back to the device's default order
    let lc_from_a = layout_for_array_copy(&la_b, TensorIterOrder::K)?;
    let lc_from_b = layout_for_array_copy(&lb_b, TensorIterOrder::K)?;
    let lc = if lc_from_a == lc_from_b {
        lc_from_a
    } else {
        match default_order {
            RowMajor => la_b.shape().c(),
            ColMajor => la_b.shape().f(),
        }
    };
    // generate empty c
    let device = a.device();
    let mut storage_c = unsafe { device.empty_impl(lc.bounds_index()?.1)? };
    // op provided by device
    device.op_mutc_refa_refb_func(storage_c.raw_mut(), &lc, a.raw(), &la_b, b.raw(), &lb_b, f)?;
    // return tensor
    Tensor::new_f(storage_c, lc)
}
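
/// Applies `f(&mut a, &b)` element-wise, mutating `a` in place.
///
/// `b` must be broadcastable to the layout of `a` without changing it.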
pub fn op_muta_refb_func<RA, RB, DA, DB, TA, TB, B, F>(
    a: &mut TensorAny<RA, TA, B, DA>,
    b: &TensorAny<RB, TB, B, DB>,
    f: &mut F,
) -> Result<()>
where
    // lifetime and data constraints
    RA: DataMutAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
    RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
    DA: DimAPI,
    DB: DimAPI,
    B: DeviceAPI<TA> + DeviceAPI<TB>,
    // broadcast constraints
    DA: DimMaxAPI<DB, Max = DA>,
    // operation constraints
    B: DeviceOp_MutA_RefB_API<TA, TB, DA, F>,
    F: FnMut(&mut TA, &TB),
{
    rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
    let la = a.layout();
    let lb = b.layout();
    let default_order = a.device().default_order();
    // lb must be broadcastable to la:
    // broadcast against la, then check that the result is still la
    let (la_b, lb_b) = broadcast_layout_to_first(la, lb, default_order)?;
    rstsr_assert_eq!(la_b, *la, InvalidLayout)?;
    // op provided by device
    let device = a.device().clone();
    device.op_muta_refb_func(a.raw_mut(), &la_b, b.raw(), &lb_b, f)
}
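
/// Applies `f(&mut a)` element-wise, mutating `a` in place.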
pub fn op_muta_func<R, T, D, B, F>(a: &mut TensorAny<R, T, B, D>, f: &mut F) -> Result<()>
where
    R: DataMutAPI<Data = B::Raw>,
    D: DimAPI,
    B: DeviceAPI<T>,
    B: DeviceOp_MutA_API<T, D, F>,
    F: FnMut(&mut T),
{
    let la = a.layout().clone();
    let device = a.device().clone();
    device.op_muta_func(a.raw_mut(), &la, f)
}

/* #endregion */
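
// Usage sketch (illustrative only, not part of this module): with `a` and `b`
// as broadcastable `f64` tensors on the same device, an element-wise addition
// can be written in terms of these helpers as
//
//     let c = op_refa_refb_func(&a, &b, &mut |c, a, b| *c = a + b)?;
//
// and an in-place doubling of `a` as
//
//     op_muta_func(&mut a, &mut |x| *x *= 2.0)?;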