custos_math/ops/
diagflat.rs1#[cfg(feature = "cuda")]
2use crate::cu_to_cpu_s;
3use crate::Matrix;
4#[cfg(feature = "cuda")]
5use custos::CUDA;
6use custos::{CDatatype, Device, MainMemory};
7
8#[cfg(feature = "cpu")]
9use custos::{cache::Cache, cpu::CPU};
10
11#[cfg(feature = "opencl")]
12use super::cl_to_cpu_s;
13#[cfg(feature = "opencl")]
14use custos::OpenCL;
15
16impl<'a, T, D: DiagflatOp<T>> Matrix<'a, T, D> {
17 pub fn diagflat(&self) -> Matrix<'a, T, D> {
18 self.device().diagflat(self)
19 }
20}
21
/// Writes the elements of `a` onto the main diagonal of `b`, where `b` is
/// interpreted as a flat `n * n` buffer with `n = a.len()`.
///
/// Only the diagonal positions (`k * n + k` for `k` in `0..n`) are written;
/// every other element of `b` is left exactly as the caller provided it.
///
/// # Panics
///
/// Panics if `b` holds fewer than `a.len() * a.len()` elements (or if that
/// product overflows `usize`). The check happens before any write, so `b`
/// is never left partially updated.
pub fn diagflat<T: Copy>(a: &[T], b: &mut [T]) {
    let n = a.len();

    // Validate once up front instead of relying on the per-element index
    // check, which would only fire after some diagonal entries were written.
    let fits = n.checked_mul(n).map_or(false, |required| b.len() >= required);
    assert!(
        fits,
        "diagflat: output buffer has {} elements but a {n}x{n} matrix is required",
        b.len(),
    );

    // Consecutive diagonal entries are `n + 1` apart in the flat buffer.
    // `n + 1` cannot overflow here: `n == usize::MAX` fails the check above.
    for (slot, &value) in b.iter_mut().step_by(n + 1).zip(a) {
        *slot = value;
    }
}
27
28pub trait DiagflatOp<T, D: Device = Self>: Device {
29 fn diagflat(&self, x: &Matrix<T, D>) -> Matrix<T, Self>;
30}
31
#[cfg(feature = "cpu")]
impl<T, D> DiagflatOp<T, D> for CPU
where
    T: Default + Copy,
    D: MainMemory,
{
    /// Builds a `side x side` matrix (with `side = x.size()`) and writes the
    /// elements of `x` onto its diagonal via the free function `diagflat`.
    ///
    /// # Panics
    ///
    /// Panics when neither dimension of `x` equals 1.
    fn diagflat(&self, x: &Matrix<T, D>) -> Matrix<T> {
        assert!(x.dims().0 == 1 || x.dims().1 == 1);

        let side = x.size();

        // NOTE(review): only the diagonal is written below, so off-diagonal
        // entries keep whatever `Cache::get` hands back — presumably zeros
        // (`T::default()`); confirm against the cache implementation.
        let mut diagonal = Cache::get(self, side * side, x.node.idx);
        diagflat(x, &mut diagonal);

        (diagonal, (side, side)).into()
    }
}
43
#[cfg(feature = "cuda")]
impl<T> DiagflatOp<T> for CUDA
where
    T: Copy + Default,
{
    /// Computes `diagflat` by handing the input to `cu_to_cpu_s`, which
    /// invokes the closure with a CPU device so the CPU implementation does
    /// the actual work.
    // NOTE(review): presumably `cu_to_cpu_s` copies the data to the host and
    // the result back to the CUDA device — confirm against its definition.
    #[inline]
    fn diagflat(&self, x: &Matrix<T, Self>) -> Matrix<T, Self> {
        cu_to_cpu_s(self, x, |host, host_x| host.diagflat(&host_x))
    }
}
51
#[cfg(feature = "opencl")]
impl<T> DiagflatOp<T> for OpenCL
where
    T: CDatatype,
{
    /// Computes `diagflat` by handing the input to `cl_to_cpu_s`, which
    /// invokes the closure with another device whose `diagflat` does the
    /// actual work.
    // NOTE(review): presumably the closure receives a CPU device and a
    // host-accessible view of `x` — confirm against `cl_to_cpu_s`.
    #[inline]
    fn diagflat(&self, x: &Matrix<T, Self>) -> Matrix<T, Self> {
        cl_to_cpu_s(self, x, |host, host_x| host.diagflat(host_x))
    }
}