// `CudaError` is only referenced on the `cuda` code paths; gating the import
// keeps non-cuda builds free of an unused-import warning.
#[cfg(feature = "cuda")]
use crate::error::CudaError;
use crate::error::CudaResult;
use crate::ffi;
5
/// Wrapper around a raw CUDA stream handle (`ffi::cudaStream_t`).
///
/// Streams created with `CudaStream::new` are destroyed on drop (only when the
/// `cuda` feature is enabled); the default stream is never destroyed.
#[derive(Debug)]
pub struct CudaStream {
    /// Raw stream handle. Null for the default stream and for every stream in
    /// builds without the `cuda` feature.
    handle: ffi::cudaStream_t,
    /// `true` for the default stream, which this wrapper does not own and must
    /// therefore never pass to `cudaStreamDestroy`.
    is_default: bool,
}
12
13impl CudaStream {
14 pub fn new() -> CudaResult<Self> {
16 #[cfg(feature = "cuda")]
17 {
18 let mut handle: ffi::cudaStream_t = std::ptr::null_mut();
19
20 unsafe {
21 let err = ffi::cudaStreamCreate(&mut handle);
22 if err != 0 {
23 return Err(CudaError::DriverError(err));
24 }
25 }
26
27 Ok(CudaStream {
28 handle,
29 is_default: false,
30 })
31 }
32
33 #[cfg(not(feature = "cuda"))]
34 {
35 Ok(CudaStream {
36 handle: std::ptr::null_mut(),
37 is_default: false,
38 })
39 }
40 }
41
42 pub fn default_stream() -> Self {
44 CudaStream {
45 handle: std::ptr::null_mut(),
46 is_default: true,
47 }
48 }
49
50 pub fn handle(&self) -> ffi::cudaStream_t {
52 self.handle
53 }
54
55 pub fn synchronize(&self) -> CudaResult<()> {
57 #[cfg(feature = "cuda")]
58 unsafe {
59 let err = ffi::cudaStreamSynchronize(self.handle);
60 if err != 0 {
61 return Err(CudaError::SyncError);
62 }
63 }
64 Ok(())
65 }
66
67 pub fn is_complete(&self) -> CudaResult<bool> {
69 #[cfg(feature = "cuda")]
70 unsafe {
71 let err = ffi::cudaStreamQuery(self.handle);
72 if err == 0 {
73 return Ok(true);
74 } else if err == 600 { return Ok(false);
76 } else {
77 return Err(CudaError::DriverError(err));
78 }
79 }
80
81 #[cfg(not(feature = "cuda"))]
82 Ok(true)
83 }
84
85 pub fn wait_event(&self, _event: &CudaEvent) -> CudaResult<()> {
87 #[cfg(feature = "cuda")]
88 unsafe {
89 self.synchronize()?;
92 }
93 Ok(())
94 }
95}
96
97impl Default for CudaStream {
98 fn default() -> Self {
99 Self::default_stream()
100 }
101}
102
103impl Drop for CudaStream {
104 fn drop(&mut self) {
105 if !self.is_default && !self.handle.is_null() {
106 #[cfg(feature = "cuda")]
107 unsafe {
108 let _ = ffi::cudaStreamDestroy(self.handle);
109 }
110 }
111 }
112}
113
114unsafe impl Send for CudaStream {}
115unsafe impl Sync for CudaStream {}
116
/// Wrapper around a raw CUDA event handle (`ffi::cudaEvent_t`).
///
/// A non-null handle is passed to `cudaEventDestroy` on drop (only when the
/// `cuda` feature is enabled).
#[derive(Debug)]
pub struct CudaEvent {
    /// Raw event handle. It is null in builds without the `cuda` feature and
    /// when `Default` falls back after a failed `CudaEvent::new`.
    handle: ffi::cudaEvent_t,
}
122
123impl CudaEvent {
124 pub fn new() -> CudaResult<Self> {
126 #[cfg(feature = "cuda")]
127 {
128 let mut handle: ffi::cudaEvent_t = std::ptr::null_mut();
129
130 unsafe {
131 let err = ffi::cudaEventCreate(&mut handle);
132 if err != 0 {
133 return Err(CudaError::DriverError(err));
134 }
135 }
136
137 Ok(CudaEvent { handle })
138 }
139
140 #[cfg(not(feature = "cuda"))]
141 {
142 Ok(CudaEvent {
143 handle: std::ptr::null_mut(),
144 })
145 }
146 }
147
148 pub fn record(&self, _stream: &CudaStream) -> CudaResult<()> {
150 #[cfg(feature = "cuda")]
151 unsafe {
152 let err = ffi::cudaEventRecord(self.handle, _stream.handle());
153 if err != 0 {
154 return Err(CudaError::DriverError(err));
155 }
156 }
157 Ok(())
158 }
159
160 pub fn synchronize(&self) -> CudaResult<()> {
162 #[cfg(feature = "cuda")]
163 unsafe {
164 let err = ffi::cudaEventSynchronize(self.handle);
165 if err != 0 {
166 return Err(CudaError::SyncError);
167 }
168 }
169 Ok(())
170 }
171
172 pub fn elapsed_time(_start: &CudaEvent, _end: &CudaEvent) -> CudaResult<f32> {
174 #[cfg(feature = "cuda")]
175 {
176 let mut ms: f32 = 0.0;
177
178 unsafe {
179 let err = ffi::cudaEventElapsedTime(&mut ms, _start.handle, _end.handle);
180 if err != 0 {
181 return Err(CudaError::DriverError(err));
182 }
183 }
184
185 Ok(ms)
186 }
187
188 #[cfg(not(feature = "cuda"))]
189 Ok(0.0)
190 }
191}
192
193impl Default for CudaEvent {
194 fn default() -> Self {
195 Self::new().unwrap_or(CudaEvent {
196 handle: std::ptr::null_mut(),
197 })
198 }
199}
200
201impl Drop for CudaEvent {
202 fn drop(&mut self) {
203 if !self.handle.is_null() {
204 #[cfg(feature = "cuda")]
205 unsafe {
206 let _ = ffi::cudaEventDestroy(self.handle);
207 }
208 }
209 }
210}
211
212unsafe impl Send for CudaEvent {}
213unsafe impl Sync for CudaEvent {}
214
215pub struct CudaTimer {
217 start: CudaEvent,
218 stop: CudaEvent,
219 stream: CudaStream,
220}
221
222impl CudaTimer {
223 pub fn new(stream: CudaStream) -> CudaResult<Self> {
224 Ok(CudaTimer {
225 start: CudaEvent::new()?,
226 stop: CudaEvent::new()?,
227 stream,
228 })
229 }
230
231 pub fn start(&self) -> CudaResult<()> {
232 self.start.record(&self.stream)
233 }
234
235 pub fn stop(&self) -> CudaResult<()> {
236 self.stop.record(&self.stream)
237 }
238
239 pub fn elapsed_ms(&self) -> CudaResult<f32> {
240 self.stop.synchronize()?;
241 CudaEvent::elapsed_time(&self.start, &self.stop)
242 }
243}