custos_math/ops/nn/softmax.rs

#[cfg(feature = "opencl")]
use crate::{
    cl_diagflat,
    ops::{cl_to_cpu_lr, cl_to_cpu_s},
};
use crate::{
    matrix_multiply::MatrixMultiply, ColOp, FnsOps, Matrix, MaxOps, SumOverOps, TransposeOp,
};
use custos::{number::Float, range, Device, GenericBlas, CPU};
#[cfg(feature = "opencl")]
use custos::{CDatatype, OpenCL};

#[cfg(feature = "cuda")]
use crate::{cu_to_cpu_lr, cu_to_cpu_s};
#[cfg(feature = "cuda")]
use custos::CUDA;

impl<'a, T, D: SoftmaxOps<T>> Matrix<'a, T, D> {
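    /// Applies the row-wise softmax to this matrix on its device.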
    pub fn softmax(&self) -> Matrix<'a, T, D> {
        self.device().softmax(self)
    }
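
    /// Computes the softmax gradient, treating `self` as the incoming gradients
    /// and `activated` as the previously computed softmax output; delegates to
    /// the matrix's device.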
    pub fn softmax_grad(&self, activated: &Matrix<T, D>) -> Matrix<'a, T, D> {
        self.device().softmax_grad(activated, self)
    }
}

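/// Device-level softmax operations: the forward pass and its gradient.
/// Implemented natively for `CPU`; `CUDA` and `OpenCL` forward to the CPU implementation.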
pub trait SoftmaxOps<T, D: Device = Self>: Device {
    fn softmax(&self, inputs: &Matrix<T, D>) -> Matrix<T, Self>;
    fn softmax_grad(&self, activated: &Matrix<T, D>, grads: &Matrix<T, D>) -> Matrix<T, Self>;
}

#[cfg(feature = "cpu")]
impl<T: Float + GenericBlas + MatrixMultiply> SoftmaxOps<T> for CPU
where
    CPU: ColOp<T>,
{
    fn softmax(&self, inputs: &Matrix<T>) -> Matrix<T> {
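        // Subtract the per-row maximum for numerical stability, exponentiate,
        // then normalize each row by its sum of exponentials.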
        let exp = self.exp(&self.sub_col(inputs, &self.max_cols(inputs)));
        self.div_col(&exp, &self.sum_cols(&exp))
    }

    #[cfg(not(feature = "safe"))]
    fn softmax_grad(&self, activated: &Matrix<T>, grads: &Matrix<T>) -> Matrix<T> {
        use custos::Cache;

        use crate::{BaseOps, Gemm};

        let mut data: Matrix<T> = (Cache::get(self, grads.len(), ()), grads.dims()).into();

        let rows = grads.rows();
        let cols = grads.cols();

        for idx in range(rows - 1) {
            let index = idx * cols;

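            // Reinterpret the current row of `activated` and `grads` as
            // (cols x 1) column matrices without copying, by reusing the
            // underlying buffer pointers.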
            let single_out = Matrix::from((
                self,
                (&activated[index..index + cols]).as_ptr() as *mut T,
                (cols, 1),
            ));
            let single_grad = Matrix::from((
                self,
                (&grads[index..index + cols]).as_ptr() as *mut T,
                (cols, 1),
            ));

            let diagflat = single_out.diagflat();

            // (cols x 1) * (1 x cols) -> cols x cols Jacobian
            let jacobian_matrix =
                self.sub(&diagflat, &self.gemm(&single_out, &single_out.T::<()>()));

            //GenericBlas::gemm();
            let res: Matrix<T> = jacobian_matrix.gemm(&single_grad);

            let data_row = &mut data[index..index + cols];
            data_row.copy_from_slice(&res);
        }
        data
    }

    #[cfg(feature = "safe")]
    fn softmax_grad(&self, activated: &Matrix<T>, grads: &Matrix<T>) -> Matrix<T> {
        use crate::{cached, BaseOps, Gemm};

        let device = CPU::new();
        let mut data = cached(self, grads.dims());

        let rows = grads.rows();
        let cols = grads.cols();

        for idx in range(rows - 1) {
            let index = idx * cols;

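            // In "safe" mode the row slices are copied into new matrices
            // instead of aliasing the original buffers.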
            let single_out =
                Matrix::from((&device, (cols, 1), &activated[index..index + cols].to_vec()));

            let single_grad =
                Matrix::from((&device, (cols, 1), &grads[index..index + cols].to_vec()));

            let diagflat = self.diagflat(&single_out);

            let jacobian_matrix = self.sub(
                &diagflat,
                &self.gemm(&single_out, &self.transpose(&single_out)),
            );
            // (cols x cols) * (cols x 1) -> cols x 1
            let res = self.gemm(&jacobian_matrix, &single_grad);

            let data_row = &mut data[index..index + cols];
            data_row.copy_from_slice(res.as_slice());
        }
        data
    }
}

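// The CUDA and OpenCL implementations below do not run softmax on the device yet:
// the `cu_to_cpu_*` / `cl_to_cpu_*` helpers transfer the buffers to a CPU device,
// reuse the CPU implementation, and copy the result back.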
#[cfg(feature = "cuda")]
impl<T: GenericBlas + MatrixMultiply + Float> SoftmaxOps<T> for CUDA {
    fn softmax(&self, inputs: &Matrix<T, Self>) -> Matrix<T, Self> {
        cu_to_cpu_s(self, inputs, |cpu, x| cpu.softmax(&x))
    }

    fn softmax_grad(
        &self,
        activated: &Matrix<T, Self>,
        grads: &Matrix<T, Self>,
    ) -> Matrix<T, Self> {
        cu_to_cpu_lr(self, activated, grads, |cpu, activated, grads| {
            cpu.softmax_grad(activated, grads)
        })
    }
}

// TODO: softmax running natively on the OpenCL device
#[cfg(feature = "opencl")]
impl<T: GenericBlas + MatrixMultiply + Float> SoftmaxOps<T> for OpenCL {
    fn softmax(&self, inputs: &Matrix<T, Self>) -> Matrix<T, Self> {
        cl_to_cpu_s(self, inputs, |device, inputs| device.softmax(inputs))
    }

    fn softmax_grad(
        &self,
        activated: &Matrix<T, Self>,
        grads: &Matrix<T, Self>,
    ) -> Matrix<T, Self> {
        cl_to_cpu_lr(self, activated, grads, |device, activated, grads| {
            device.softmax_grad(activated, grads)
        })
    }
}

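/// Experimental softmax gradient computed directly on the OpenCL device:
/// builds `diagflat(activated) - activated * activatedᵀ` as the Jacobian
/// and multiplies the incoming gradients with it.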
#[cfg(feature = "opencl")]
pub fn cl_softmax<'a, T: CDatatype>(
    device: &'a OpenCL,
    mut activated: Matrix<T, OpenCL>,
    grads: &Matrix<T, OpenCL>,
) -> custos::Result<Matrix<'a, T, OpenCL>> {
    use crate::{cl_tew, Gemm, SumOverOps};

    let rows = grads.rows();
    let cols = grads.cols();

    let diag = cl_diagflat(device, &activated, activated.rows(), activated.cols())?;

    //println!("diag: {diag:?}");
    activated.reshape((cols, rows));

    // (cols x rows) * (rows x cols) -> cols x cols
    let jacobian = cl_tew(
        device,
        &diag,
        &device.gemm(&activated, &device.transpose(&activated)),
        "-",
    )?;

    println!("jacobian: {jacobian:?}");

    let jacobian = (jacobian, rows, cols * cols).into();
    let mut jacobian = device.sum_rows(&jacobian);
    jacobian.reshape((cols, cols));

    // (rows x cols) * (cols x cols) -> rows x cols
    let res = device.gemm(grads, &jacobian);
    Ok(res)
}
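
// A minimal usage sketch, not part of the original file: it assumes the crate is
// built with the `cpu` feature, that a BLAS-backed `GenericBlas` impl exists for
// f32, and it reuses the `Matrix::from((&device, dims, &vec))` constructor seen
// in the "safe" branch above.
#[cfg(all(test, feature = "cpu"))]
mod softmax_usage_sketch {
    use super::*;

    #[test]
    fn rows_sum_to_one() {
        let device = CPU::new();
        // 2 x 3 input; softmax is applied independently to each row.
        let data = vec![1f32, 2., 3., 3., 2., 1.];
        let x = Matrix::from((&device, (2, 3), &data));
        let out = x.softmax();
        // Every row of the softmax output should sum to (approximately) 1.
        for row in out.as_slice().chunks(3) {
            let sum: f32 = row.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5);
        }
    }
}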