1#[cfg(feature = "cuda")]
7use std::sync::Arc;
8
9#[cfg(feature = "cuda")]
10use cudarc::cublas::CudaBlas;
11#[cfg(feature = "cuda")]
12use cudarc::driver::{CudaContext, CudaStream};
13
14#[cfg(not(feature = "cuda"))]
15use crate::error::GpuError;
16use crate::error::GpuResult;
17
/// Handle to one CUDA device, available when the `cuda` feature is enabled.
///
/// Groups together the driver context, a stream, a cuBLAS handle, and the
/// ordinal the device was opened with.
#[cfg(feature = "cuda")]
pub struct GpuDevice {
    /// Shared driver context for this device.
    ctx: Arc<CudaContext>,
    /// Stream held by this handle; this is what `default_stream` hands out.
    stream: Arc<CudaStream>,
    /// cuBLAS handle; every constructor in this file builds it from `stream`.
    blas: CudaBlas,
    /// Device ordinal supplied when the handle was created.
    ordinal: usize,
}
31
#[cfg(feature = "cuda")]
impl GpuDevice {
    /// Opens the CUDA device identified by `ordinal`.
    ///
    /// Creates a driver context, takes that context's default stream, and
    /// builds a cuBLAS handle on the stream.
    ///
    /// # Errors
    ///
    /// Propagates any failure reported by `CudaContext::new` or
    /// `CudaBlas::new`.
    pub fn new(ordinal: usize) -> GpuResult<Self> {
        let context = CudaContext::new(ordinal)?;
        let default_stream = context.default_stream();
        let handle = CudaBlas::new(Arc::clone(&default_stream))?;
        Ok(Self {
            ctx: context,
            stream: default_stream,
            blas: handle,
            ordinal,
        })
    }

    /// Builds a second handle for the same device on a stream forked from
    /// `parent`'s stream.
    ///
    /// The new handle shares `parent`'s context and ordinal but gets its own
    /// cuBLAS handle created on the forked stream.
    ///
    /// NOTE(review): the name suggests this is meant for stream/graph
    /// capture; confirm against callers.
    ///
    /// # Errors
    ///
    /// Propagates any failure from forking the stream or from
    /// `CudaBlas::new`.
    pub fn fork_for_capture(parent: &GpuDevice) -> GpuResult<Self> {
        let forked = parent.stream.fork()?;
        let handle = CudaBlas::new(Arc::clone(&forked))?;
        Ok(Self {
            ctx: parent.ctx.clone(),
            stream: forked,
            blas: handle,
            ordinal: parent.ordinal,
        })
    }

    /// Returns the shared driver context.
    #[inline]
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    /// Returns the stream stored in this handle, bypassing whatever
    /// selection `stream` performs.
    #[inline]
    pub fn default_stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }

    /// Returns the stream selected by
    /// `crate::stream::current_stream_or_default` for this device.
    ///
    /// NOTE(review): presumably an active override stream with a fallback to
    /// `default_stream`; verify in `crate::stream`.
    #[inline]
    pub fn stream(&self) -> Arc<CudaStream> {
        crate::stream::current_stream_or_default(self)
    }

    /// Returns the cuBLAS handle owned by this device handle.
    #[inline]
    pub fn blas(&self) -> &CudaBlas {
        &self.blas
    }

    /// Returns the ordinal this handle was created with.
    #[inline]
    pub fn ordinal(&self) -> usize {
        self.ordinal
    }
}
95
#[cfg(feature = "cuda")]
impl Clone for GpuDevice {
    /// Produces a handle that shares this one's context and stream but owns
    /// a newly created cuBLAS handle on that same stream.
    ///
    /// # Panics
    ///
    /// Panics if `CudaBlas::new` fails, because `Clone::clone` has no way to
    /// report an error.
    fn clone(&self) -> Self {
        // Create the cuBLAS handle first so a failure happens before any
        // other state is duplicated.
        let handle = CudaBlas::new(Arc::clone(&self.stream))
            .expect("CudaBlas::new failed in GpuDevice::clone");
        Self {
            ctx: self.ctx.clone(),
            stream: self.stream.clone(),
            blas: handle,
            ordinal: self.ordinal,
        }
    }
}
109
#[cfg(feature = "cuda")]
impl std::fmt::Debug for GpuDevice {
    /// Formats the handle as `GpuDevice { ordinal: .., .. }`.
    ///
    /// Only the ordinal is printed; the remaining fields are elided through
    /// the non-exhaustive finisher.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("GpuDevice");
        builder.field("ordinal", &self.ordinal);
        builder.finish_non_exhaustive()
    }
}
118
/// Placeholder device handle used when the `cuda` feature is disabled.
///
/// It carries only an ordinal, and `GpuDevice::new` always returns an error
/// in this configuration.
#[cfg(not(feature = "cuda"))]
#[derive(Debug, Clone)]
pub struct GpuDevice {
    /// Device ordinal.
    ordinal: usize,
}
131
132#[cfg(not(feature = "cuda"))]
133impl GpuDevice {
134 pub fn new(ordinal: usize) -> GpuResult<Self> {
136 let _ = ordinal;
137 Err(GpuError::NoCudaFeature)
138 }
139
140 #[inline]
142 pub fn ordinal(&self) -> usize {
143 self.ordinal
144 }
145}