custos_math/
cpu.rs

1use crate::Matrix;
2use custos::{number::Number, Alloc, MainMemory, Shape};
3
4pub fn scalar_apply<'a, T, F, D, S, Host>(
5    device: &'a Host,
6    lhs: &Matrix<T, D, S>,
7    scalar: T,
8    f: F,
9) -> Matrix<'a, T, Host, S>
10where
11    T: Number,
12    F: Fn(&mut T, T, T),
13    D: MainMemory,
14    S: Shape,
15    Host: for<'b> Alloc<'b, T, S> + MainMemory,
16{
17    let mut out = device.retrieve(lhs.len(), lhs.node.idx);
18    scalar_apply_slice(lhs, &mut out, scalar, f);
19    (out, lhs.dims()).into()
20}
21
/// Combines every element of `lhs` with `scalar` through `f`, writing into `out`.
///
/// For each index `i` in `0..out.len()`, `f(&mut out[i], lhs[i], scalar)` is called.
/// Elements of `lhs` past `out.len()` are ignored.
///
/// # Panics
///
/// Panics if `lhs` is shorter than `out`. The length is checked once before the
/// loop, so `out` is left unmodified when the check fails.
#[inline]
pub fn scalar_apply_slice<T, F>(lhs: &[T], out: &mut [T], scalar: T, f: F)
where
    T: Copy,
    F: Fn(&mut T, T, T),
{
    // One up-front bounds check replaces a check on every `lhs[idx]` access.
    let lhs = &lhs[..out.len()];

    for (value, &lhs_value) in out.iter_mut().zip(lhs) {
        f(value, lhs_value, scalar);
    }
}
32
/// Combines each row of a row-major `lhs` with the row vector `rhs`, writing to `out`.
///
/// `lhs` is read as `lrows` consecutive rows of `lcols` elements. For every row `r`
/// and every index `c` in `0..rhs.len()`, the call
/// `f(&mut out[r * lcols + c], lhs[r * lcols + c], rhs[c])` is made.
///
/// If `rhs` holds fewer than `lcols` elements, the remaining columns of each row in
/// `out` are not touched.
///
/// # Panics
///
/// Panics if `lhs` is shorter than `lrows * lcols`, if `rhs` has more than `lcols`
/// elements, or if an accessed index lies outside `out`.
pub fn row_op_slice_mut<T, F>(
    lhs: &[T],
    lrows: usize,
    lcols: usize,
    rhs: &[T],
    out: &mut [T],
    f: F,
) where
    F: Fn(&mut T, T, T),
    T: Copy,
{
    for row_start in (0..lrows).map(|row| row * lcols) {
        let lhs_row = &lhs[row_start..row_start + lcols];

        for (col, rhs_value) in rhs.iter().copied().enumerate() {
            let target = &mut out[row_start + col];
            f(target, lhs_row[col], rhs_value);
        }
    }
}
47
/// Updates each row of a row-major `lhs` in place using the row vector `rhs`.
///
/// `lhs` is treated as `lhs_rows` consecutive rows of `lhs_cols` elements. For every
/// row `r` and every index `c` in `0..rhs.len()`, the call
/// `f(&mut lhs[r * lhs_cols + c], rhs[c])` is made.
///
/// # Panics
///
/// Panics if an accessed index `r * lhs_cols + c` lies outside `lhs`.
pub fn row_op_slice_lhs<T, F>(lhs: &mut [T], lhs_rows: usize, lhs_cols: usize, rhs: &[T], f: F)
where
    F: Fn(&mut T, T),
    T: Copy,
{
    for row_start in (0..lhs_rows).map(|row| row * lhs_cols) {
        for (col, rhs_value) in rhs.iter().copied().enumerate() {
            let target = &mut lhs[row_start + col];
            f(target, rhs_value);
        }
    }
}
61
62pub fn row_op<'a, T, F, D, Host, LS: Shape, RS: Shape>(
63    device: &'a Host,
64    lhs: &Matrix<T, D, LS>,
65    rhs: &Matrix<T, D, RS>,
66    f: F,
67) -> Matrix<'a, T, Host, LS>
68where
69    T: Number,
70    F: Fn(&mut T, T, T),
71    D: MainMemory,
72    Host: for<'b> Alloc<'b, T, LS> + MainMemory,
73{
74    assert!(rhs.rows() == 1 && rhs.cols() == lhs.cols());
75
76    let mut out = device.retrieve(lhs.len(), [lhs.node.idx, rhs.node.idx]);
77    row_op_slice_mut(lhs, lhs.rows(), lhs.cols(), rhs, &mut out, f);
78    (out, lhs.dims()).into()
79}
80
81pub fn col_op<'a, T, F, D, Host>(
82    device: &'a Host,
83    lhs: &Matrix<T, D>,
84    rhs: &Matrix<T, D>,
85    f: F,
86) -> Matrix<'a, T, Host>
87where
88    T: Number,
89    F: Fn(&mut T, T, T),
90    D: MainMemory,
91    Host: for<'b> Alloc<'b, T> + MainMemory,
92{
93    let mut out = device.retrieve(lhs.len(), [lhs.node.idx, rhs.node.idx]);
94    col_op_slice_mut(lhs, lhs.rows(), lhs.cols(), rhs, &mut out, f);
95    (out, lhs.dims()).into()
96}
97
/// Combines each row of a row-major `lhs` with one element of `rhs`, writing to `out`.
///
/// `lhs` is read as rows of `lcols` elements. For every row `r` in
/// `0..min(lrows, rhs.len())` and every column `c` in `0..lcols`, the call
/// `f(&mut out[r * lcols + c], lhs[r * lcols + c], rhs[r])` is made.
///
/// Rows with no corresponding `rhs` element are skipped, leaving their part of `out`
/// untouched. Elements of `rhs` past `lrows` are ignored.
///
/// `T` only needs to be `Copy`, matching the other slice helpers in this module.
///
/// # Panics
///
/// Panics if `lhs` or `out` is too short to hold a processed row. The row slices are
/// taken before `f` runs on that row, so a row is never partially written.
pub fn col_op_slice_mut<T, F>(lhs: &[T], lrows: usize, lcols: usize, rhs: &[T], out: &mut [T], f: F)
where
    T: Copy,
    F: Fn(&mut T, T, T),
{
    for (row, &rhs_value) in rhs.iter().enumerate().take(lrows) {
        let start = row * lcols;
        let end = start + lcols;

        // Slicing both rows once bounds-checks the whole row up front.
        let lhs_row = &lhs[start..end];
        let out_row = &mut out[start..end];

        for (value, &lhs_value) in out_row.iter_mut().zip(lhs_row) {
            f(value, lhs_value, rhs_value);
        }
    }
}
113
114pub fn each_op<'a, T, F, D, S, Host>(
115    device: &'a Host,
116    x: &Matrix<T, D, S>,
117    f: F,
118) -> Matrix<'a, T, Host, S>
119where
120    T: Copy,
121    F: Fn(T) -> T,
122    D: MainMemory,
123    Host: for<'b> Alloc<'b, T, S> + MainMemory,
124    S: Shape,
125{
126    let mut out = device.retrieve(x.len(), x.node.idx);
127    each_op_slice(x, &mut out, f);
128    (out, x.dims()).into()
129}
130
/// Maps every element of `x` through `f`, storing the results in `out`.
///
/// For each index `i` in `0..out.len()`, `out[i]` is set to `f(x[i])`. Elements of
/// `x` past `out.len()` are ignored.
///
/// # Panics
///
/// Panics if `x` is shorter than `out`. The length is checked once before the loop,
/// so `out` is left unmodified when the check fails.
pub fn each_op_slice<T, F>(x: &[T], out: &mut [T], f: F)
where
    T: Copy,
    F: Fn(T) -> T,
{
    // One up-front bounds check replaces a check on every `x[idx]` access.
    let x = &x[..out.len()];

    for (value, &input) in out.iter_mut().zip(x) {
        *value = f(input);
    }
}
140
/// Replaces every element of `x` in place with the result of applying `f` to it.
///
/// Elements are visited in order; an empty slice is a no-op.
pub fn each_op_slice_mut<T, F>(x: &mut [T], f: F)
where
    F: Fn(T) -> T,
    T: Copy,
{
    x.iter_mut().for_each(|slot| {
        let current = *slot;
        *slot = f(current);
    });
}