1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
use crate::shapes::{Shape, Unit};
use crate::tensor::cpu::{Cpu, CpuError, NdIndex};
use crate::tensor::{DeviceStorage, HasErr, Tensor};

use cudarc::{
    cublas::{result::CublasError, CudaBlas},
    driver::{CudaDevice, CudaSlice, CudaStream, DeviceSlice, DriverError},
};
use std::{sync::Arc, vec::Vec};

/// A Cuda device that enables constructing tensors on GPUs
/// & running GPU kernels.
///
/// The struct derives `Clone`; the device, cuBLAS and stream handles are
/// held behind [Arc], so cloning those fields only clones the `Arc`.
#[derive(Clone, Debug)]
pub struct Cuda {
    /// The [Cpu] device created from the seed in [Cuda::try_build]; it backs
    /// `random_u64` for this device.
    pub(crate) cpu: Cpu,
    /// Handle to the cuda device opened with `CudaDevice::new(ordinal)`.
    pub(crate) dev: Arc<CudaDevice>,
    /// cuBLAS handle created from `dev`.
    pub(crate) blas: Arc<CudaBlas>,
    /// A second stream for kernels to optionally execute on.
    /// Obtained from `dev.fork_default_stream()` at construction time.
    pub(crate) par_stream: Arc<CudaStream>,
}

/// The error type of the [Cuda] device.
///
/// Each variant wraps the error of one underlying component, and a `From`
/// impl exists for every wrapped type so that `?` can be used on their results.
#[derive(Debug)]
pub enum CudaError {
    /// An error reported by cuBLAS.
    Blas(CublasError),
    /// An error reported by the cuda driver.
    Driver(DriverError),
    /// An error coming from the [Cpu] device.
    Cpu(CpuError),
}

impl From<CpuError> for CudaError {
    fn from(value: CpuError) -> Self {
        Self::Cpu(value)
    }
}

impl From<CublasError> for CudaError {
    fn from(value: CublasError) -> Self {
        Self::Blas(value)
    }
}

impl From<DriverError> for CudaError {
    fn from(value: DriverError) -> Self {
        Self::Driver(value)
    }
}

impl Default for Cuda {
    fn default() -> Self {
        Self::seed_from_u64(0)
    }
}

impl Cuda {
    /// Constructs a [Cuda] device on ordinal `0`, seeding its internal
    /// [Cpu] device with `seed`.
    ///
    /// # Panics
    /// Panics if [Cuda::try_seed_from_u64] returns an error.
    pub fn seed_from_u64(seed: u64) -> Self {
        let built = Self::try_seed_from_u64(seed);
        built.unwrap()
    }

    /// Fallible version of [Cuda::seed_from_u64].
    ///
    /// # Errors
    /// Returns whatever error [Cuda::try_build] produces for ordinal `0`.
    pub fn try_seed_from_u64(seed: u64) -> Result<Self, CudaError> {
        const DEFAULT_ORDINAL: usize = 0;
        Self::try_build(DEFAULT_ORDINAL, seed)
    }

    /// Constructs a [Cuda] device for the given device `ordinal`, seeding
    /// its internal [Cpu] device with `seed`.
    ///
    /// # Errors
    /// Returns a [CudaError] if opening the device, creating the cuBLAS
    /// handle, or forking the default stream fails.
    pub fn try_build(ordinal: usize, seed: u64) -> Result<Self, CudaError> {
        let cpu = Cpu::seed_from_u64(seed);
        let dev = CudaDevice::new(ordinal)?;
        // Both handles are wrapped in `Arc` before the error is propagated.
        let blas = CudaBlas::new(dev.clone()).map(Arc::new)?;
        let par_stream = dev.fork_default_stream().map(Arc::new)?;
        let device = Self {
            cpu,
            dev,
            blas,
            par_stream,
        };
        Ok(device)
    }

    /// Block until kernels finish processing. Useful for benchmarking.
    ///
    /// Examples:
    /// ```rust
    /// # use dfdx::prelude::*;
    /// let dev: Cuda = Default::default();
    /// let a = dev.tensor([1., 2., 3.]);
    /// let _b = a.square();
    /// dev.synchronize().unwrap(); // blocks until square kernel finishes.
    /// ```
    ///
    /// # Errors
    /// Returns [CudaError::Driver] if the underlying device synchronization fails.
    pub fn synchronize(&self) -> Result<(), CudaError> {
        self.dev.synchronize()?;
        Ok(())
    }
}

impl std::fmt::Display for CudaError {
    /// Writes the error using its `Debug` representation.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Goes through `format_args!` (not `Debug::fmt(self, f)`) so the debug
        // output is rendered with default formatting options.
        f.write_fmt(format_args!("{:?}", self))
    }
}

/// Declares [CudaError] as the error type associated with the [Cuda] device.
impl HasErr for Cuda {
    // Used as `Self::Err` by the fallible `DeviceStorage` methods of this device.
    type Err = CudaError;
}

impl DeviceStorage for Cuda {
    /// Tensor storage on this device is a [CudaSlice].
    type Vec<E: Unit> = CudaSlice<E>;

    /// Allocates a zeroed buffer with the same number of elements as `other`.
    ///
    /// Returns an error if the allocation fails.
    fn try_alloc_grad<E: Unit>(&self, other: &Self::Vec<E>) -> Result<Self::Vec<E>, Self::Err> {
        let num_elements = other.len();
        Ok(self.dev.alloc_zeros(num_elements)?)
    }

    /// Produces a random `u64` by delegating to the internal [Cpu] device.
    fn random_u64(&self) -> u64 {
        let host_device = &self.cpu;
        host_device.random_u64()
    }

    /// Returns the elements of `tensor` as a `Vec`, in the order produced by
    /// an [NdIndex] built from the tensor's shape and strides.
    ///
    /// # Panics
    /// Panics if cloning the tensor's storage or converting the clone into a
    /// `Vec` fails.
    fn tensor_to_vec<S: Shape, E: Unit, T>(&self, tensor: &Tensor<S, E, Self, T>) -> Vec<E> {
        // Clone the storage, then turn the clone into a host-side vector.
        let cloned = tensor.data.try_clone().unwrap();
        let host: Vec<E> = cloned.try_into().unwrap();
        debug_assert_eq!(host.len(), tensor.data.len());

        // `NdIndex` yields the buffer offset of each logical element in turn.
        let mut offsets = NdIndex::new(tensor.shape, tensor.strides);
        let mut gathered = Vec::with_capacity(tensor.shape.num_elements());
        gathered.extend(std::iter::from_fn(|| offsets.next()).map(|offset| host[offset]));
        gathered
    }
}