#[cfg(feature = "opencl")]
use crate::{
    cl_diagflat,
    ops::{cl_to_cpu_lr, cl_to_cpu_s},
};
use crate::{
    matrix_multiply::MatrixMultiply, ColOp, DiagflatOp, FnsOps, Matrix, MaxOps, SumOverOps,
    TransposeOp,
};
use custos::{number::Float, range, Device, GenericBlas, CPU};
#[cfg(feature = "opencl")]
use custos::{CDatatype, OpenCL};

#[cfg(feature = "cuda")]
use crate::{cu_to_cpu_lr, cu_to_cpu_s};
#[cfg(feature = "cuda")]
use custos::CUDA;

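// Matrix convenience methods: forward to the `SoftmaxOps` implementation of
// the matrix's device.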
impl<'a, T, D: SoftmaxOps<T>> Matrix<'a, T, D> {
    pub fn softmax(&self) -> Matrix<'a, T, D> {
        self.device().softmax(self)
    }

    pub fn softmax_grad(&self, activated: &Matrix<T, D>) -> Matrix<'a, T, D> {
        self.device().softmax_grad(activated, self)
    }
}

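/// Softmax forward and backward passes, implemented per device.
///
/// For one activated row `s = softmax(x)`, the Jacobian is
/// `J = diagflat(s) - s * sᵀ`, so the input gradient of that row is
/// `J * grad`. This is what the CPU `softmax_grad` below computes row by row.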
pub trait SoftmaxOps<T, D: Device = Self>: Device {
    fn softmax(&self, inputs: &Matrix<T, D>) -> Matrix<T, Self>;
    fn softmax_grad(&self, activated: &Matrix<T, D>, grads: &Matrix<T, D>) -> Matrix<T, Self>;
}

#[cfg(feature = "cpu")]
impl<T: Float + GenericBlas + MatrixMultiply> SoftmaxOps<T> for CPU
where
    CPU: ColOp<T>,
{
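    // Numerically stable softmax: shift the inputs by their maximum before
    // exponentiating, then normalize by the sums of the exponentials.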
    fn softmax(&self, inputs: &Matrix<T>) -> Matrix<T> {
        let exp = self.exp(&self.sub_col(inputs, &self.max_cols(inputs)));
        self.div_col(&exp, &self.sum_cols(&exp))
    }

    #[cfg(not(feature = "safe"))]
    fn softmax_grad(&self, activated: &Matrix<T>, grads: &Matrix<T>) -> Matrix<T> {
        use custos::Cache;

        use crate::{BaseOps, Gemm};

        let mut data: Matrix<T> = (Cache::get(self, grads.len(), ()), grads.dims()).into();

        let rows = grads.rows();
        let cols = grads.cols();

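        // Process the batch row by row: view one activated row and its
        // gradient as column vectors, build the Jacobian, and apply it.
        // `custos::range` counts inclusively, hence `rows - 1` to visit
        // every row index 0..=rows - 1.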
        for idx in range(rows - 1) {
            let index = idx * cols;

            let single_out = Matrix::from((
                self,
                (&activated[index..index + cols]).as_ptr() as *mut T,
                (cols, 1),
            ));
            let single_grad = Matrix::from((
                self,
                (&grads[index..index + cols]).as_ptr() as *mut T,
                (cols, 1),
            ));

            let diagflat = single_out.diagflat();

            // J = diagflat(s) - s * sᵀ
            let jacobian_matrix =
                self.sub(&diagflat, &self.gemm(&single_out, &self.transpose(&single_out)));

            let res: Matrix<T> = jacobian_matrix.gemm(&single_grad);

            let data_row = &mut data[index..index + cols];
            data_row.copy_from_slice(&res);
        }
        data
    }

    #[cfg(feature = "safe")]
    fn softmax_grad(&self, activated: &Matrix<T>, grads: &Matrix<T>) -> Matrix<T> {
        use crate::{cached, BaseOps, Gemm};

        let device = CPU::new();
        let mut data = cached(self, grads.dims());

        let rows = grads.rows();
        let cols = grads.cols();

        for idx in range(rows - 1) {
            let index = idx * cols;

            let single_out =
                Matrix::from((&device, (cols, 1), &activated[index..index + cols].to_vec()));

            let single_grad =
                Matrix::from((&device, (cols, 1), &grads[index..index + cols].to_vec()));

            let diagflat = self.diagflat(&single_out);

            let jacobian_matrix = self.sub(
                &diagflat,
                &self.gemm(&single_out, &self.transpose(&single_out)),
            );
            let res = self.gemm(&jacobian_matrix, &single_grad);

            let data_row = &mut data[index..index + cols];
            data_row.copy_from_slice(res.as_slice());
        }
        data
    }
}

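// CUDA fallback: copy the matrices to the host, run the CPU implementation,
// and copy the result back to the device.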
#[cfg(feature = "cuda")]
impl<T: GenericBlas + MatrixMultiply + Float> SoftmaxOps<T> for CUDA {
    fn softmax(&self, inputs: &Matrix<T, Self>) -> Matrix<T, Self> {
        cu_to_cpu_s(self, inputs, |cpu, x| cpu.softmax(&x))
    }

    fn softmax_grad(
        &self,
        activated: &Matrix<T, Self>,
        grads: &Matrix<T, Self>,
    ) -> Matrix<T, Self> {
        cu_to_cpu_lr(self, activated, grads, |cpu, activated, grads| {
            cpu.softmax_grad(activated, grads)
        })
    }
}

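// OpenCL fallback: same host round-trip as the CUDA implementation.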
#[cfg(feature = "opencl")]
impl<T: GenericBlas + MatrixMultiply + Float> SoftmaxOps<T> for OpenCL {
    fn softmax(&self, inputs: &Matrix<T, Self>) -> Matrix<T, Self> {
        cl_to_cpu_s(self, inputs, |device, inputs| device.softmax(inputs))
    }

    fn softmax_grad(
        &self,
        activated: &Matrix<T, Self>,
        grads: &Matrix<T, Self>,
    ) -> Matrix<T, Self> {
        cl_to_cpu_lr(self, activated, grads, |device, activated, grads| {
            device.softmax_grad(activated, grads)
        })
    }
}

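// Experimental on-device OpenCL path: assembles the softmax gradient directly
// from `cl_diagflat`, `cl_tew` and gemm instead of the host round-trip above.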
#[cfg(feature = "opencl")]
pub fn cl_softmax<'a, T: CDatatype>(
    device: &'a OpenCL,
    mut activated: Matrix<T, OpenCL>,
    grads: &Matrix<T, OpenCL>,
) -> custos::Result<Matrix<'a, T, OpenCL>> {
    use crate::{cl_tew, Gemm};

    let rows = grads.rows();
    let cols = grads.cols();

    let diag = cl_diagflat(device, &activated, activated.rows(), activated.cols())?;

    activated.reshape((cols, rows));

    let jacobian = cl_tew(
        device,
        &diag,
        &device.gemm(&activated, &device.transpose(&activated)),
        "-",
    )?;

    let jacobian: Matrix<T, OpenCL> = (jacobian, rows, cols * cols).into();
    let mut jacobian = device.sum_rows(&jacobian);
    jacobian.reshape((cols, cols));

    let res = device.gemm(grads, &jacobian);
    Ok(res)
}