1use crate::Matrix;
2use custos::{number::Number, Alloc, MainMemory, Shape};
3
4pub fn scalar_apply<'a, T, F, D, S, Host>(
5 device: &'a Host,
6 lhs: &Matrix<T, D, S>,
7 scalar: T,
8 f: F,
9) -> Matrix<'a, T, Host, S>
10where
11 T: Number,
12 F: Fn(&mut T, T, T),
13 D: MainMemory,
14 S: Shape,
15 Host: for<'b> Alloc<'b, T, S> + MainMemory,
16{
17 let mut out = device.retrieve(lhs.len(), lhs.node.idx);
18 scalar_apply_slice(lhs, &mut out, scalar, f);
19 (out, lhs.dims()).into()
20}
21
/// Element-wise applies `f` between `lhs` and a constant `scalar`, writing into `out`.
///
/// For every index `i < out.len()`, `f(&mut out[i], lhs[i], scalar)` is called. Extra
/// trailing elements of `lhs` beyond `out.len()` are ignored.
///
/// # Panics
/// Panics if `lhs` has fewer elements than `out`. The check runs before any element of
/// `out` is written.
#[inline]
pub fn scalar_apply_slice<T, F>(lhs: &[T], out: &mut [T], scalar: T, f: F)
where
    T: Copy,
    F: Fn(&mut T, T, T),
{
    // One up-front check instead of a bounds check per element; it also avoids leaving
    // `out` partially written when `lhs` is too short.
    assert!(
        lhs.len() >= out.len(),
        "scalar_apply_slice: lhs has {} elements but out needs {}",
        lhs.len(),
        out.len()
    );

    for (value, &x) in out.iter_mut().zip(lhs) {
        f(value, x, scalar);
    }
}
32
/// Combines every row of the row-major `lrows` x `lcols` matrix `lhs` with the row vector
/// `rhs`, writing each result to the same position in `out`.
///
/// For each row `r` and each column `c < rhs.len()`, calls
/// `f(&mut out[r * lcols + c], lhs[r * lcols + c], rhs[c])`.
///
/// # Panics
/// Panics on any out-of-bounds access. Examples: `lhs` holding fewer than `lrows * lcols`
/// elements, or `rhs.len() > lcols` while `lrows > 0`.
pub fn row_op_slice_mut<T, F>(lhs: &[T], lrows: usize, lcols: usize, rhs: &[T], out: &mut [T], f: F)
where
    T: Copy,
    F: Fn(&mut T, T, T),
{
    (0..lrows).for_each(|row| {
        let base = row * lcols;
        let lhs_row = &lhs[base..base + lcols];

        for (col, &r) in rhs.iter().enumerate() {
            f(&mut out[base + col], lhs_row[col], r);
        }
    });
}
47
/// In-place counterpart of `row_op_slice_mut`: combines every row of the row-major
/// `lhs_rows` x `lhs_cols` matrix `lhs` with the row vector `rhs`.
///
/// For each row `r` and each column `c < rhs.len()`, calls
/// `f(&mut lhs[r * lhs_cols + c], rhs[c])`.
///
/// # Panics
/// Panics if an accessed index `r * lhs_cols + c` lies outside `lhs`.
pub fn row_op_slice_lhs<T, F>(lhs: &mut [T], lhs_rows: usize, lhs_cols: usize, rhs: &[T], f: F)
where
    T: Copy,
    F: Fn(&mut T, T),
{
    for base in (0..lhs_rows).map(|row| row * lhs_cols) {
        rhs.iter()
            .enumerate()
            .for_each(|(col, &r)| f(&mut lhs[base + col], r));
    }
}
61
62pub fn row_op<'a, T, F, D, Host, LS: Shape, RS: Shape>(
63 device: &'a Host,
64 lhs: &Matrix<T, D, LS>,
65 rhs: &Matrix<T, D, RS>,
66 f: F,
67) -> Matrix<'a, T, Host, LS>
68where
69 T: Number,
70 F: Fn(&mut T, T, T),
71 D: MainMemory,
72 Host: for<'b> Alloc<'b, T, LS> + MainMemory,
73{
74 assert!(rhs.rows() == 1 && rhs.cols() == lhs.cols());
75
76 let mut out = device.retrieve(lhs.len(), [lhs.node.idx, rhs.node.idx]);
77 row_op_slice_mut(lhs, lhs.rows(), lhs.cols(), rhs, &mut out, f);
78 (out, lhs.dims()).into()
79}
80
81pub fn col_op<'a, T, F, D, Host>(
82 device: &'a Host,
83 lhs: &Matrix<T, D>,
84 rhs: &Matrix<T, D>,
85 f: F,
86) -> Matrix<'a, T, Host>
87where
88 T: Number,
89 F: Fn(&mut T, T, T),
90 D: MainMemory,
91 Host: for<'b> Alloc<'b, T> + MainMemory,
92{
93 let mut out = device.retrieve(lhs.len(), [lhs.node.idx, rhs.node.idx]);
94 col_op_slice_mut(lhs, lhs.rows(), lhs.cols(), rhs, &mut out, f);
95 (out, lhs.dims()).into()
96}
97
/// Combines every row of the row-major `lrows` x `lcols` matrix `lhs` with the matching entry
/// of the column vector `rhs`, writing each result to the same position in `out`.
///
/// For each row `r` and column `c < lcols`, calls
/// `f(&mut out[r * lcols + c], lhs[r * lcols + c], rhs[r])`. Entries of `rhs` past `lrows`
/// are ignored.
///
/// Only `T: Copy` is required, matching the other slice helpers in this module.
///
/// # Panics
/// Panics if `rhs` has fewer than `lrows` entries, because the trailing rows of `out` would
/// otherwise silently stay unwritten. Also panics if `lhs` or `out` holds fewer than
/// `lrows * lcols` elements.
pub fn col_op_slice_mut<T, F>(lhs: &[T], lrows: usize, lcols: usize, rhs: &[T], out: &mut [T], f: F)
where
    T: Copy,
    F: Fn(&mut T, T, T),
{
    assert!(
        rhs.len() >= lrows,
        "col_op_slice_mut: rhs has {} entries but lhs has {} rows",
        rhs.len(),
        lrows
    );

    for (row, &r) in rhs[..lrows].iter().enumerate() {
        let start = row * lcols;
        let lhs_row = &lhs[start..start + lcols];
        let out_row = &mut out[start..start + lcols];

        for (dst, &l) in out_row.iter_mut().zip(lhs_row) {
            f(dst, l, r);
        }
    }
}
113
114pub fn each_op<'a, T, F, D, S, Host>(
115 device: &'a Host,
116 x: &Matrix<T, D, S>,
117 f: F,
118) -> Matrix<'a, T, Host, S>
119where
120 T: Copy,
121 F: Fn(T) -> T,
122 D: MainMemory,
123 Host: for<'b> Alloc<'b, T, S> + MainMemory,
124 S: Shape,
125{
126 let mut out = device.retrieve(x.len(), x.node.idx);
127 each_op_slice(x, &mut out, f);
128 (out, x.dims()).into()
129}
130
/// Writes `f(x[i])` into `out[i]` for every index `i < out.len()`.
///
/// Extra trailing elements of `x` beyond `out.len()` are ignored.
///
/// # Panics
/// Panics if `x` has fewer elements than `out`. The check runs before any element of
/// `out` is written.
pub fn each_op_slice<T, F>(x: &[T], out: &mut [T], f: F)
where
    T: Copy,
    F: Fn(T) -> T,
{
    // Single up-front check replaces the per-element bounds check of `x[idx]` and avoids
    // a partially written `out` on a length mismatch.
    assert!(
        x.len() >= out.len(),
        "each_op_slice: input has {} elements but out needs {}",
        x.len(),
        out.len()
    );

    for (dst, &src) in out.iter_mut().zip(x) {
        *dst = f(src);
    }
}
140
/// Replaces every element of `x` in place with `f` applied to its current value.
pub fn each_op_slice_mut<T, F>(x: &mut [T], f: F)
where
    T: Copy,
    F: Fn(T) -> T,
{
    x.iter_mut().for_each(|elem| *elem = f(*elem));
}