custos_math/ops/nn/loss/mse.rs

1use crate::{AdditionalOps, BaseOps, Matrix, SumOps};
2#[cfg(feature = "opencl")]
3use custos::{opencl::enqueue_kernel, OpenCL};
4use custos::{prelude::Number, CDatatype, IsShapeIndep, Shape};
5
6#[inline]
7pub fn mse<'a, T, D, S>(
8    preds: &Matrix<'a, T, D, S>,
9    targets: &Matrix<'a, T, D>,
10) -> (T, Matrix<'a, T, D>)
11where
12    T: Number,
13    D: IsShapeIndep + BaseOps<T> + SumOps<T> + AdditionalOps<T>,
14    S: Shape,
15{
16    let preds = preds.as_dims();
17    (mse_loss(preds, targets), mse_grad(preds, targets))
18}
19
20pub fn mse_loss<T, D, S>(preds: &Matrix<T, D, S>, targets: &Matrix<T, D, S>) -> T
21where
22    D: BaseOps<T, S> + SumOps<T, S>,
23    S: Shape,
24{
25    let x = preds - targets;
26    (&x * &x).mean()
27}
28
29pub fn mse_grad<'a, T, D, S>(
30    preds: &Matrix<'a, T, D, S>,
31    targets: &Matrix<'a, T, D, S>,
32) -> Matrix<'a, T, D, S>
33where
34    T: Number,
35    D: BaseOps<T, S> + SumOps<T, S> + AdditionalOps<T, S>,
36    S: Shape,
37{
38    let x = preds - targets;
39    (&x * T::two() / T::from_usize(preds.cols())) / T::from_usize(preds.rows())
40}
41
/// OpenCL implementation of the MSE gradient with respect to `preds`.
///
/// A single element-wise kernel computes
/// `out[i] = ((preds[i] - targets[i]) * 2 / cols) / rows`,
/// where `cols`/`rows` are `preds.cols()`/`preds.rows()` converted to `T`. This
/// mirrors the CPU formula in [`mse_grad`]. The output buffer is obtained via
/// `device.retrieve`, with `preds` and `targets` recorded as its parents. The result
/// takes the dimensions of `preds`.
///
/// # Panics
///
/// Panics if `preds` and `targets` contain a different number of elements, or if the
/// kernel fails to build or enqueue.
#[cfg(feature = "opencl")]
pub fn mse_grad_cl<'a, T: CDatatype>(
    device: &'a OpenCL,
    preds: &Matrix<'a, T, OpenCL>,
    targets: &Matrix<'a, T, OpenCL>,
) -> Matrix<'a, T, OpenCL> {
    use custos::Device;

    // The kernel is launched with one work-item per element of `preds` and reads
    // `targets[id]` for each of them; a shorter `targets` buffer would be read out of
    // bounds on the device (undefined behaviour), so reject mismatched inputs up front.
    assert_eq!(
        preds.len(),
        targets.len(),
        "mse_grad_cl: preds and targets must contain the same number of elements"
    );

    let src = format!(
        "
        __kernel void mse_grad(__global const {datatype}* preds, 
            __global const {datatype}* targets, 
            __global {datatype}* out,
            const {datatype} cols, const {datatype} rows) 
            
        {{
            size_t id = get_global_id(0);

            {datatype} x = (preds[id] - targets[id]) * 2;
            out[id] = (x / cols) / rows;
        }}
    ",
        datatype = T::as_c_type_str()
    );

    let out: custos::Buffer<T, OpenCL> =
        device.retrieve(preds.len(), (preds.node.idx, targets.node.idx));
    enqueue_kernel(
        device,
        &src,
        // One work-item per element; unused dimensions are 0.
        [preds.len(), 0, 0],
        None,
        &[
            preds,
            targets,
            &out,
            &T::from_usize(preds.cols()),
            &T::from_usize(preds.rows()),
        ],
    )
    .expect("mse_grad_cl: failed to build or enqueue the mse_grad OpenCL kernel");
    (out, preds.dims()).into()
}