custos_math/ops/diagflat.rs

1#[cfg(feature = "cuda")]
2use crate::cu_to_cpu_s;
3use crate::Matrix;
4#[cfg(feature = "cuda")]
5use custos::CUDA;
6use custos::{CDatatype, Device, MainMemory};
7
8#[cfg(feature = "cpu")]
9use custos::{cache::Cache, cpu::CPU};
10
11#[cfg(feature = "opencl")]
12use super::cl_to_cpu_s;
13#[cfg(feature = "opencl")]
14use custos::OpenCL;
15
16impl<'a, T, D: DiagflatOp<T>> Matrix<'a, T, D> {
17    pub fn diagflat(&self) -> Matrix<'a, T, D> {
18        self.device().diagflat(self)
19    }
20}
21
/// Writes the elements of `a` onto the main diagonal of `b`, where `b` is
/// viewed as a row-major `a.len() x a.len()` matrix.
///
/// Only the `a.len()` diagonal entries of `b` are written; every other entry
/// (including anything past the leading `a.len() * a.len()` elements) is left
/// untouched. Callers that want a true diagonal matrix must therefore pass a
/// zero/default-initialised buffer.
///
/// # Panics
///
/// Panics if `b.len() < a.len() * a.len()`.
pub fn diagflat<T: Copy>(a: &[T], b: &mut [T]) {
    let n = a.len();
    assert!(
        b.len() >= n * n,
        "diagflat: output buffer holds {} elements but a {}x{} matrix needs {}",
        b.len(),
        n,
        n,
        n * n
    );
    // Consecutive diagonal entries of a row-major n x n matrix are `n + 1`
    // elements apart; `zip` stops after the `n` values of `a`.
    for (dst, &src) in b.iter_mut().step_by(n + 1).zip(a) {
        *dst = src;
    }
}
27
/// Device-level operation that turns a vector into a diagonal matrix.
///
/// `D` is the device the input matrix lives on and defaults to `Self`. In this
/// file the CPU backend accepts input from any `MainMemory` device, while the
/// CUDA and OpenCL backends only implement the default `D = Self`.
pub trait DiagflatOp<T, D: Device = Self>: Device {
    /// Returns a matrix on device `Self` whose main diagonal holds the
    /// elements of `x`.
    fn diagflat(&self, x: &Matrix<T, D>) -> Matrix<T, Self>;
}
31
#[cfg(feature = "cpu")]
impl<T: Default + Copy, D: MainMemory> DiagflatOp<T, D> for CPU {
    /// Builds a `size x size` CPU matrix with the elements of `x` on its main
    /// diagonal, where `size = x.size()`.
    ///
    /// `x` may come from any `MainMemory` device; its data is read directly as
    /// a host slice.
    ///
    /// # Panics
    ///
    /// Panics if `x` is not a row or column vector (neither dimension is 1).
    fn diagflat(&self, x: &Matrix<T, D>) -> Matrix<T> {
        let dims = x.dims();
        assert!(
            dims.0 == 1 || dims.1 == 1,
            "diagflat expects a row or column vector, but got a {}x{} matrix",
            dims.0,
            dims.1
        );
        let size = x.size();

        // NOTE(review): only the diagonal is written below, so correctness
        // relies on `Cache::get` handing back a buffer whose off-diagonal
        // entries are already zero/default — confirm for reused cache entries.
        let mut out = Cache::get(self, size * size, x.node.idx);
        diagflat(x, &mut out);
        (out, (size, size)).into()
    }
}
43
#[cfg(feature = "cuda")]
impl<T: Copy + Default> DiagflatOp<T> for CUDA {
    /// Computes `diagflat` for a CUDA matrix through `cu_to_cpu_s`, whose
    /// closure runs the CPU implementation of this op.
    ///
    /// NOTE(review): presumably `cu_to_cpu_s` copies `x` to the host and the
    /// result back to the GPU — confirm in its definition.
    #[inline]
    fn diagflat(&self, x: &Matrix<T, Self>) -> Matrix<T, Self> {
        cu_to_cpu_s(self, x, |host, host_x| host.diagflat(&host_x))
    }
}
51
#[cfg(feature = "opencl")]
impl<T: CDatatype> DiagflatOp<T> for OpenCL {
    /// Computes `diagflat` for an OpenCL matrix through `cl_to_cpu_s`, whose
    /// closure runs the CPU-side implementation of this op.
    ///
    /// NOTE(review): presumably `cl_to_cpu_s` copies `x` to the host and the
    /// result back to the OpenCL device — confirm in its definition.
    #[inline]
    fn diagflat(&self, x: &Matrix<T, Self>) -> Matrix<T, Self> {
        cl_to_cpu_s(self, x, |cpu, host_x| cpu.diagflat(host_x))
    }
}