1use crate::Matrix;
2use custos::{number::Number, CDatatype, Device, MainMemory, CPU};
3
4#[cfg(feature = "cpu")]
5use custos::cache::Cache;
6
7#[cfg(feature = "cuda")]
8use crate::{cu_to_cpu_s, cu_to_cpu_scalar};
9#[cfg(feature = "cuda")]
10use custos::CUDA;
11
12#[cfg(feature = "opencl")]
13use super::{cl_to_cpu_s, cl_to_cpu_scalar};
14#[cfg(feature = "opencl")]
15use custos::OpenCL;
16
17impl<'a, T, D: MaxOps<T>> Matrix<'a, T, D> {
18 #[inline]
19 pub fn max(&self) -> T {
20 self.device().max(self)
21 }
22
23 #[inline]
24 pub fn max_rows(&self) -> Matrix<'a, T, D> {
25 self.device().max_rows(self)
26 }
27
28 #[inline]
29 pub fn max_cols(&self) -> Matrix<'a, T, D> {
30 self.device().max_cols(self)
31 }
32}
33
34pub trait MaxOps<T, D: Device = Self>: Device {
35 fn max(&self, x: &Matrix<T, D>) -> T;
36 fn max_rows(&self, x: &Matrix<T, D>) -> Matrix<T, Self>;
37 fn max_cols(&self, x: &Matrix<T, D>) -> Matrix<T, Self>;
38}
39
#[cfg(feature = "cpu")]
impl<T: Copy + PartialOrd, D: MainMemory> MaxOps<T, D> for CPU {
    /// Returns the largest element of `x`.
    ///
    /// The scan starts from the first element and only replaces the running
    /// maximum when a later element compares strictly greater (`PartialOrd`),
    /// so incomparable values such as NaN never replace it.
    ///
    /// # Panics
    /// Panics if `x` contains no elements.
    fn max(&self, x: &Matrix<T, D>) -> T {
        let (&first, rest) = x
            .as_slice()
            .split_first()
            .expect("max: cannot take the maximum of an empty matrix");

        rest.iter().copied().fold(first, greater_of)
    }

    /// Reduces `x` (`rows × cols`) to a `1 × cols` matrix holding, for every
    /// column, the maximum across all rows.
    ///
    /// # Panics
    /// Panics if `x` has columns but no rows, or if its buffer holds fewer than
    /// `rows * cols` elements.
    fn max_rows(&self, x: &Matrix<T, D>) -> Matrix<T> {
        let (rows, cols) = (x.rows(), x.cols());
        let mut out = Cache::get(self, cols, x.node.idx);

        // Only the logical `rows * cols` prefix of the buffer belongs to the matrix.
        let data = &x.as_slice()[..rows * cols];
        let max_rows = out.as_mut_slice();

        // Without a single row there is nothing to seed the column maxima with.
        assert!(
            rows > 0 || cols == 0,
            "max_rows: cannot reduce a matrix with no rows"
        );

        // Seed the result with the first row, then fold the remaining rows in;
        // the first row never needs to be compared against itself.
        let (first_row, rest) = data.split_at(cols);
        max_rows.copy_from_slice(first_row);

        // `chunks_exact(0)` panics; with zero columns there is nothing to reduce.
        if cols > 0 {
            for row in rest.chunks_exact(cols) {
                for (current, &value) in max_rows.iter_mut().zip(row) {
                    *current = greater_of(*current, value);
                }
            }
        }
        (out, 1, cols).into()
    }

    /// Reduces `x` (`rows × cols`) to a `rows × 1` matrix holding the maximum
    /// of each row.
    ///
    /// # Panics
    /// Panics if `x` has rows but no columns, or if its buffer holds fewer than
    /// `rows * cols` elements.
    fn max_cols(&self, x: &Matrix<T, D>) -> Matrix<T> {
        let (rows, cols) = (x.rows(), x.cols());
        let data = &x.as_slice()[..rows * cols];
        let mut y = Cache::get(self, rows, x.node.idx);
        let max_cols = y.as_mut_slice();

        // A row with no columns has no maximum.
        assert!(
            cols > 0 || rows == 0,
            "max_cols: cannot reduce a matrix with no columns"
        );

        // `chunks_exact(0)` panics; with `cols == 0` we also have `rows == 0` here.
        if cols > 0 {
            for (row_max, row) in max_cols.iter_mut().zip(data.chunks_exact(cols)) {
                // `cols > 0`, so every row has a first element to seed from.
                *row_max = row[1..].iter().copied().fold(row[0], greater_of);
            }
        }
        (y, rows, 1).into()
    }
}

/// Returns `candidate` if it compares strictly greater than `current`,
/// otherwise `current`; incomparable pairs (e.g. involving NaN) keep `current`.
#[cfg(feature = "cpu")]
#[inline]
fn greater_of<T: PartialOrd>(current: T, candidate: T) -> T {
    if candidate > current {
        candidate
    } else {
        current
    }
}
97
/// OpenCL implementation of [`MaxOps`].
///
/// Each reduction is handed to the `cl_to_cpu_*` helpers together with a closure
/// that runs the corresponding CPU-side `max*` call.
/// NOTE(review): the host/device transfer behaviour lives in those helpers — confirm there.
#[cfg(feature = "opencl")]
impl<T: CDatatype> MaxOps<T> for OpenCL {
    /// Largest element of `x`, computed via `cl_to_cpu_scalar`.
    fn max(&self, x: &Matrix<T, Self>) -> T {
        cl_to_cpu_scalar(self, x, |cpu, host_matrix| cpu.max(host_matrix))
    }

    /// Per-column maxima of `x`, computed via `cl_to_cpu_s`.
    fn max_rows(&self, x: &Matrix<T, Self>) -> Matrix<T, Self> {
        cl_to_cpu_s(self, x, |cpu, host_matrix| cpu.max_rows(host_matrix))
    }

    /// Per-row maxima of `x`, computed via `cl_to_cpu_s`.
    fn max_cols(&self, x: &Matrix<T, Self>) -> Matrix<T, Self> {
        cl_to_cpu_s(self, x, |cpu, host_matrix| cpu.max_cols(host_matrix))
    }
}
112
/// CUDA implementation of [`MaxOps`].
///
/// Each reduction is handed to the `cu_to_cpu_*` helpers together with a closure
/// that runs the corresponding CPU-side `max*` call on the matrix it receives.
/// NOTE(review): the host/device transfer behaviour lives in those helpers — confirm there.
#[cfg(feature = "cuda")]
impl<T: Number> MaxOps<T> for CUDA {
    /// Largest element of `x`, computed via `cu_to_cpu_scalar`.
    fn max(&self, x: &Matrix<T, Self>) -> T {
        cu_to_cpu_scalar(x, |cpu, host_matrix| cpu.max(&host_matrix))
    }

    /// Per-column maxima of `x`, computed via `cu_to_cpu_s`.
    fn max_rows(&self, x: &Matrix<T, Self>) -> Matrix<T, Self> {
        cu_to_cpu_s(self, x, |cpu, host_matrix| cpu.max_rows(&host_matrix))
    }

    /// Per-row maxima of `x`, computed via `cu_to_cpu_s`.
    fn max_cols(&self, x: &Matrix<T, Self>) -> Matrix<T, Self> {
        cu_to_cpu_s(self, x, |cpu, host_matrix| cpu.max_cols(&host_matrix))
    }
}