1use crate::{AdditionalOps, BaseOps, Matrix, SumOps};
2#[cfg(feature = "opencl")]
3use custos::{opencl::enqueue_kernel, OpenCL};
4use custos::{prelude::Number, CDatatype, IsShapeIndep, Shape};
5
6#[inline]
7pub fn mse<'a, T, D, S>(
8 preds: &Matrix<'a, T, D, S>,
9 targets: &Matrix<'a, T, D>,
10) -> (T, Matrix<'a, T, D>)
11where
12 T: Number,
13 D: IsShapeIndep + BaseOps<T> + SumOps<T> + AdditionalOps<T>,
14 S: Shape,
15{
16 let preds = preds.as_dims();
17 (mse_loss(preds, targets), mse_grad(preds, targets))
18}
19
20pub fn mse_loss<T, D, S>(preds: &Matrix<T, D, S>, targets: &Matrix<T, D, S>) -> T
21where
22 D: BaseOps<T, S> + SumOps<T, S>,
23 S: Shape,
24{
25 let x = preds - targets;
26 (&x * &x).mean()
27}
28
29pub fn mse_grad<'a, T, D, S>(
30 preds: &Matrix<'a, T, D, S>,
31 targets: &Matrix<'a, T, D, S>,
32) -> Matrix<'a, T, D, S>
33where
34 T: Number,
35 D: BaseOps<T, S> + SumOps<T, S> + AdditionalOps<T, S>,
36 S: Shape,
37{
38 let x = preds - targets;
39 (&x * T::two() / T::from_usize(preds.cols())) / T::from_usize(preds.rows())
40}
41
/// OpenCL implementation of [`mse_grad`].
///
/// Builds and enqueues a kernel that computes, for every element,
/// `((preds - targets) * 2 / cols) / rows`. The result is written into a
/// buffer obtained from `device.retrieve` and wrapped as a matrix with the
/// same dimensions as `preds`.
///
/// # Panics
/// Panics if `preds` and `targets` contain a different number of elements,
/// or if the kernel fails to build or enqueue.
#[cfg(feature = "opencl")]
pub fn mse_grad_cl<'a, T: CDatatype>(
    device: &'a OpenCL,
    preds: &Matrix<'a, T, OpenCL>,
    targets: &Matrix<'a, T, OpenCL>,
) -> Matrix<'a, T, OpenCL> {
    use custos::Device;

    // One work item is launched per element of `preds`, and each reads
    // `targets[id]`. A shorter `targets` would read past the end of its device
    // buffer, so reject mismatched inputs up front.
    assert_eq!(
        preds.len(),
        targets.len(),
        "mse_grad_cl: preds and targets must have the same number of elements"
    );

    let src = format!(
        "
    __kernel void mse_grad(__global const {datatype}* preds,
        __global const {datatype}* targets,
        __global {datatype}* out,
        const {datatype} cols, const {datatype} rows)

        {{
            size_t id = get_global_id(0);

            {datatype} x = (preds[id] - targets[id]) * 2;
            out[id] = (x / cols) / rows;
        }}
    ",
        datatype = T::as_c_type_str()
    );

    let out: custos::Buffer<T, OpenCL> =
        device.retrieve(preds.len(), (preds.node.idx, targets.node.idx));
    enqueue_kernel(
        device,
        &src,
        [preds.len(), 0, 0],
        None,
        &[
            preds,
            targets,
            &out,
            // Passed as `T` to match the kernel's `{datatype} cols/rows` params.
            &T::from_usize(preds.cols()),
            &T::from_usize(preds.rows()),
        ],
    )
    .expect("mse_grad_cl: failed to build or enqueue the mse_grad OpenCL kernel");
    (out, preds.dims()).into()
}