use crate::backend::{Backend, BackendBuffer, BackendKernel};
use crate::buffer::{Buffer, GpuBuf};
use crate::dispatch::{self, DispatchConfig};
use crate::error::{GpuError, Result};
use crate::kernel::Kernel;
use crate::shader;

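/// A handle to a single GPU, wrapping whichever backend (Vulkan or CUDA,
/// depending on enabled crate features) it was created on.
///
/// A minimal usage sketch (hypothetical data; assumes a supported GPU and
/// driver are present):
///
/// ```ignore
/// let device = Device::auto()?;
/// let buf = device.upload(&[1.0f32, 2.0, 3.0])?;
/// println!("running on {}", device.name());
/// ```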
pub struct Device {
    inner: DeviceInner,
}

enum DeviceInner {
    #[cfg(feature = "vulkan")]
    Vulkan(crate::backend::vulkan::VulkanBackend),
    #[cfg(feature = "cuda")]
    Cuda(crate::backend::cuda::CudaBackend),
}

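/// Identifies which compute backend a [`Device`] is using.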
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendKind {
    Vulkan,
    Cuda,
}

impl Device {
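    /// Creates a device on the first available backend, preferring CUDA over
    /// Vulkan when both features are compiled in. Returns
    /// [`GpuError::NoDevice`] if no backend can be initialized.
    ///
    /// A minimal sketch (assumes at least one supported GPU is present):
    ///
    /// ```ignore
    /// let device = Device::auto()?;
    /// ```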
    pub fn auto() -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            use crate::backend::cuda::CudaBackend;
            if let Ok(backend) = CudaBackend::create() {
                return Ok(Self {
                    inner: DeviceInner::Cuda(backend),
                });
            }
        }

        #[cfg(feature = "vulkan")]
        {
            use crate::backend::vulkan::VulkanBackend;
            if let Ok(backend) = VulkanBackend::create() {
                return Ok(Self {
                    inner: DeviceInner::Vulkan(backend),
                });
            }
        }

        Err(GpuError::NoDevice)
    }

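    /// Creates a device on the requested backend. Fails with
    /// [`GpuError::BackendUnavailable`] if the corresponding crate feature is
    /// not enabled, or with the backend's own error if initialization fails.
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let device = Device::with_backend(BackendKind::Vulkan)?;
    /// ```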
    pub fn with_backend(kind: BackendKind) -> Result<Self> {
        match kind {
            BackendKind::Vulkan => {
                #[cfg(feature = "vulkan")]
                {
                    use crate::backend::vulkan::VulkanBackend;
                    let backend = VulkanBackend::create()?;
                    Ok(Self {
                        inner: DeviceInner::Vulkan(backend),
                    })
                }
                #[cfg(not(feature = "vulkan"))]
                {
                    Err(GpuError::BackendUnavailable(
                        "vulkan feature not enabled".into(),
                    ))
                }
            }
            BackendKind::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    use crate::backend::cuda::CudaBackend;
                    let backend = CudaBackend::create()?;
                    Ok(Self {
                        inner: DeviceInner::Cuda(backend),
                    })
                }
                #[cfg(not(feature = "cuda"))]
                {
                    Err(GpuError::BackendUnavailable(
                        "cuda feature not enabled".into(),
                    ))
                }
            }
        }
    }

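    /// Uploads a host slice of plain-old-data values to a new device buffer.
    ///
    /// A minimal sketch (hypothetical data):
    ///
    /// ```ignore
    /// let buf = device.upload(&[1.0f32, 2.0, 3.0, 4.0])?;
    /// ```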
    pub fn upload<T: bytemuck::Pod>(&self, data: &[T]) -> Result<Buffer<T>> {
        let bytes = bytemuck::cast_slice(data);
        let inner = self.upload_raw(bytes)?;
        Ok(Buffer {
            inner,
            len: data.len(),
            _marker: std::marker::PhantomData,
        })
    }

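    /// Allocates a device buffer sized for `count` elements of type `T`.
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let out = device.alloc::<f32>(1024)?;
    /// ```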
    pub fn alloc<T: bytemuck::Pod>(&self, count: usize) -> Result<Buffer<T>> {
        // If count * size_of::<T>() overflows, report the request as u64::MAX
        // rather than a wrapped-around (and misleadingly small) byte size.
        let size = count.checked_mul(std::mem::size_of::<T>()).ok_or_else(|| {
            GpuError::AllocationFailed {
                requested: u64::MAX,
                device_max: self.memory(),
            }
        })? as u64;
        let inner = self.alloc_raw(size)?;
        Ok(Buffer {
            inner,
            len: count,
            _marker: std::marker::PhantomData,
        })
    }

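    /// Compiles a WGSL compute shader and dispatches it once over `buffers`,
    /// sizing the grid from the shader's declared workgroup size to cover
    /// `invocations` invocations. The entry point must be named `main`, and
    /// the number of buffers must match the shader's binding count (checked
    /// at dispatch time).
    ///
    /// A minimal sketch (hypothetical doubling kernel):
    ///
    /// ```ignore
    /// let data = device.upload(&[1.0f32; 1024])?;
    /// device.dispatch(
    ///     r#"
    ///     @group(0) @binding(0) var<storage, read_write> data: array<f32>;
    ///     @compute @workgroup_size(64)
    ///     fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    ///         data[id.x] = data[id.x] * 2.0;
    ///     }
    ///     "#,
    ///     &[&data],
    ///     1024,
    /// )?;
    /// ```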
    pub fn dispatch(
        &self,
        shader_src: &str,
        buffers: &[&dyn GpuBuf],
        invocations: u32,
    ) -> Result<()> {
        let entry = "main";
        let compiled = shader::compile_wgsl(shader_src, entry)?;

        let expected = shader::binding_count(&compiled.module);
        let backend_bufs: Vec<&BackendBuffer> = buffers.iter().map(|b| b.raw()).collect();
        if expected != backend_bufs.len() {
            return Err(GpuError::BindingMismatch {
                expected,
                got: backend_bufs.len(),
            });
        }

        let wg_size = dispatch::extract_workgroup_size(&compiled.module, entry);
        let workgroups = dispatch::calc_dispatch(invocations, wg_size);

        self.dispatch_spirv(&compiled.spirv, entry, &backend_bufs, workgroups, None)
    }

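    /// Like [`Device::dispatch`], but driven by a [`DispatchConfig`]: the
    /// entry point defaults to `main`, an explicit workgroup count overrides
    /// the automatic grid sizing, and push constants can be supplied.
    ///
    /// A minimal sketch (assumes `DispatchConfig`'s fields, as used below,
    /// are publicly constructible):
    ///
    /// ```ignore
    /// device.dispatch_configured(
    ///     &DispatchConfig {
    ///         shader: SRC,
    ///         entry_point: Some("main"),
    ///         invocations: 1024,
    ///         workgroups: None,
    ///         push_constants: None,
    ///     },
    ///     &[&data],
    /// )?;
    /// ```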
    pub fn dispatch_configured(
        &self,
        config: &DispatchConfig<'_>,
        buffers: &[&dyn GpuBuf],
    ) -> Result<()> {
        let entry = config.entry_point.unwrap_or("main");
        let compiled = shader::compile_wgsl(config.shader, entry)?;

        let expected = shader::binding_count(&compiled.module);
        let backend_bufs: Vec<&BackendBuffer> = buffers.iter().map(|b| b.raw()).collect();
        if expected != backend_bufs.len() {
            return Err(GpuError::BindingMismatch {
                expected,
                got: backend_bufs.len(),
            });
        }

        let workgroups = config.workgroups.unwrap_or_else(|| {
            let wg_size = dispatch::extract_workgroup_size(&compiled.module, entry);
            dispatch::calc_dispatch(config.invocations, wg_size)
        });

        self.dispatch_spirv(
            &compiled.spirv,
            entry,
            &backend_bufs,
            workgroups,
            config.push_constants,
        )
    }

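    /// Compiles a WGSL shader (entry point `main`) into a reusable
    /// [`Kernel`], avoiding recompilation when the same shader is dispatched
    /// many times.
    ///
    /// A minimal sketch (`SRC` is a placeholder WGSL source string):
    ///
    /// ```ignore
    /// let kernel = device.compile(SRC)?;
    /// device.run(&kernel, &[&data], 1024)?;
    /// ```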
    pub fn compile(&self, shader_src: &str) -> Result<Kernel> {
        self.compile_named(shader_src, "main")
    }

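    /// Like [`Device::compile`], but with an explicit entry point. The
    /// binding count, workgroup size, and push-constant size are reflected
    /// from the compiled module and cached on the returned [`Kernel`].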
    pub fn compile_named(&self, shader_src: &str, entry_point: &str) -> Result<Kernel> {
        let compiled = shader::compile_wgsl(shader_src, entry_point)?;
        let binding_count = shader::binding_count(&compiled.module);
        let workgroup_size = dispatch::extract_workgroup_size(&compiled.module, entry_point);
        let push_constant_size = shader::push_constant_size(&compiled.module);

        let inner = self.create_pipeline(
            &compiled.spirv,
            entry_point,
            binding_count,
            push_constant_size,
        )?;

        Ok(Kernel {
            inner,
            binding_count,
            workgroup_size,
            entry_point: entry_point.to_string(),
        })
    }

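    /// Dispatches a precompiled [`Kernel`] over `buffers`, computing the
    /// workgroup count from the kernel's workgroup size to cover
    /// `invocations` invocations. Fails with [`GpuError::BindingMismatch`]
    /// if the buffer count does not match the kernel's binding count.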
    pub fn run(&self, kernel: &Kernel, buffers: &[&dyn GpuBuf], invocations: u32) -> Result<()> {
        let backend_bufs: Vec<&BackendBuffer> = buffers.iter().map(|b| b.raw()).collect();
        if kernel.binding_count != backend_bufs.len() {
            return Err(GpuError::BindingMismatch {
                expected: kernel.binding_count,
                got: backend_bufs.len(),
            });
        }

        let workgroups = dispatch::calc_dispatch(invocations, kernel.workgroup_size);
        self.run_pipeline(kernel, &backend_bufs, workgroups, None)
    }

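    /// Like [`Device::run`], but passes a raw byte blob of push constants to
    /// the kernel.
    ///
    /// A minimal sketch (hypothetical scale parameter, passed as bytes):
    ///
    /// ```ignore
    /// let scale: f32 = 2.0;
    /// device.run_with_push_constants(&kernel, &[&data], 1024, bytemuck::bytes_of(&scale))?;
    /// ```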
    pub fn run_with_push_constants(
        &self,
        kernel: &Kernel,
        buffers: &[&dyn GpuBuf],
        invocations: u32,
        push_constants: &[u8],
    ) -> Result<()> {
        let backend_bufs: Vec<&BackendBuffer> = buffers.iter().map(|b| b.raw()).collect();
        if kernel.binding_count != backend_bufs.len() {
            return Err(GpuError::BindingMismatch {
                expected: kernel.binding_count,
                got: backend_bufs.len(),
            });
        }

        let workgroups = dispatch::calc_dispatch(invocations, kernel.workgroup_size);
        self.run_pipeline(kernel, &backend_bufs, workgroups, Some(push_constants))
    }

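    /// The fully explicit variant of [`Device::run`]: the caller supplies the
    /// exact `[x, y, z]` workgroup counts and optional push constants, with
    /// no automatic grid sizing.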
    pub fn run_configured(
        &self,
        kernel: &Kernel,
        buffers: &[&dyn GpuBuf],
        workgroups: [u32; 3],
        push_constants: Option<&[u8]>,
    ) -> Result<()> {
        let backend_bufs: Vec<&BackendBuffer> = buffers.iter().map(|b| b.raw()).collect();
        if kernel.binding_count != backend_bufs.len() {
            return Err(GpuError::BindingMismatch {
                expected: kernel.binding_count,
                got: backend_bufs.len(),
            });
        }

        self.run_pipeline(kernel, &backend_bufs, workgroups, push_constants)
    }

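    /// Creates a new device buffer with the same length as `src` and copies
    /// `src`'s contents into it on-device.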
    pub fn copy_buffer<T: bytemuck::Pod>(&self, src: &Buffer<T>) -> Result<Buffer<T>> {
        let size = src.byte_size();
        let inner = self.copy_buffer_raw(&src.inner, size)?;
        Ok(Buffer {
            inner,
            len: src.len,
            _marker: std::marker::PhantomData,
        })
    }

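    /// Begins a command batch for recording several operations and submitting
    /// them to the device together.
    ///
    /// A minimal sketch (the recording API lives on [`crate::batch::Batch`]
    /// and is not shown here):
    ///
    /// ```ignore
    /// let batch = device.batch()?;
    /// // ... record work on `batch`, then submit it ...
    /// ```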
    pub fn batch(&self) -> Result<crate::batch::Batch> {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => {
                let vk_batch = b.begin_batch()?;
                Ok(crate::batch::Batch::new_vulkan(vk_batch))
            }
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => {
                let cuda_batch = b.begin_batch()?;
                Ok(crate::batch::Batch::new_cuda(cuda_batch))
            }
        }
    }

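    /// Returns the device name reported by the backend driver.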
    pub fn name(&self) -> &str {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => b.device_name(),
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => b.device_name(),
        }
    }

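    /// Returns the device's total memory in bytes.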
    pub fn memory(&self) -> u64 {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => b.device_memory(),
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => b.device_memory(),
        }
    }

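    /// Returns the device's subgroup width (warp size on NVIDIA hardware) in
    /// invocations.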
    pub fn subgroup_size(&self) -> u32 {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => b.subgroup_size(),
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => b.subgroup_size(),
        }
    }

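    /// Returns which backend this device was created on.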
    pub const fn backend_kind(&self) -> BackendKind {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(_) => BackendKind::Vulkan,
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(_) => BackendKind::Cuda,
        }
    }

    /// Uploads raw bytes through the active backend and wraps the result in a
    /// [`BackendBuffer`].
    fn upload_raw(&self, data: &[u8]) -> Result<BackendBuffer> {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => {
                let buf = b.upload(data)?;
                Ok(BackendBuffer::Vulkan(buf))
            }
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => {
                let buf = b.upload(data)?;
                Ok(BackendBuffer::Cuda(buf))
            }
        }
    }

    /// Copies `size` bytes from `src` into a freshly allocated buffer on the
    /// same backend, rejecting buffers created by a different backend.
    fn copy_buffer_raw(&self, src: &BackendBuffer, size: u64) -> Result<BackendBuffer> {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => {
                // Irrefutable when Vulkan is the only compiled-in backend.
                #[allow(irrefutable_let_patterns)]
                let BackendBuffer::Vulkan(vk_src) = src else {
                    return Err(GpuError::BackendUnavailable(
                        "buffer/backend mismatch: expected Vulkan buffer".into(),
                    ));
                };
                let buf = b.copy_buffer(vk_src, size)?;
                Ok(BackendBuffer::Vulkan(buf))
            }
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => {
                // Irrefutable when CUDA is the only compiled-in backend.
                #[allow(irrefutable_let_patterns)]
                let BackendBuffer::Cuda(cuda_src) = src else {
                    return Err(GpuError::BackendUnavailable(
                        "buffer/backend mismatch: expected CUDA buffer".into(),
                    ));
                };
                let buf = b.copy_buffer(cuda_src, size)?;
                Ok(BackendBuffer::Cuda(buf))
            }
        }
    }

    /// Allocates `size` bytes of device memory on the active backend.
    fn alloc_raw(&self, size: u64) -> Result<BackendBuffer> {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => {
                let buf = b.alloc(size)?;
                Ok(BackendBuffer::Vulkan(buf))
            }
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => {
                let buf = b.alloc(size)?;
                Ok(BackendBuffer::Cuda(buf))
            }
        }
    }

    /// Dispatches already-compiled SPIR-V through the active backend, after
    /// downcasting each buffer to the backend's concrete buffer type.
    fn dispatch_spirv(
        &self,
        spirv: &[u32],
        entry_point: &str,
        buffers: &[&BackendBuffer],
        workgroups: [u32; 3],
        push_constants: Option<&[u8]>,
    ) -> Result<()> {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => {
                let vk_bufs: Vec<&crate::backend::vulkan::VulkanBuffer> = buffers
                    .iter()
                    .map(|buf| match buf {
                        BackendBuffer::Vulkan(vb) => Ok(vb),
                        #[cfg(feature = "cuda")]
                        _ => Err(GpuError::BackendUnavailable(
                            "buffer/backend mismatch: expected Vulkan buffer".into(),
                        )),
                    })
                    .collect::<Result<Vec<_>>>()?;
                b.dispatch(spirv, entry_point, &vk_bufs, workgroups, push_constants)
            }
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => {
                let cuda_bufs: Vec<&crate::backend::cuda::CudaBuffer> = buffers
                    .iter()
                    .map(|buf| match buf {
                        BackendBuffer::Cuda(cb) => Ok(cb),
                        #[cfg(feature = "vulkan")]
                        _ => Err(GpuError::BackendUnavailable(
                            "buffer/backend mismatch: expected CUDA buffer".into(),
                        )),
                    })
                    .collect::<Result<Vec<_>>>()?;
                b.dispatch(spirv, entry_point, &cuda_bufs, workgroups, push_constants)
            }
        }
    }

    /// Builds a backend pipeline/kernel object from compiled SPIR-V.
    fn create_pipeline(
        &self,
        spirv: &[u32],
        entry_point: &str,
        binding_count: usize,
        push_constant_size: u32,
    ) -> Result<BackendKernel> {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => {
                let kernel =
                    b.create_pipeline(spirv, entry_point, binding_count, push_constant_size)?;
                Ok(BackendKernel::Vulkan(kernel))
            }
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => {
                let kernel =
                    b.create_pipeline(spirv, entry_point, binding_count, push_constant_size)?;
                Ok(BackendKernel::Cuda(kernel))
            }
        }
    }

    /// Dispatches a precompiled kernel, verifying that the kernel and every
    /// buffer were created on the active backend.
    fn run_pipeline(
        &self,
        kernel: &Kernel,
        buffers: &[&BackendBuffer],
        workgroups: [u32; 3],
        push_constants: Option<&[u8]>,
    ) -> Result<()> {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => {
                #[allow(irrefutable_let_patterns)]
                let BackendKernel::Vulkan(vk_kernel) = &kernel.inner else {
                    return Err(GpuError::BackendUnavailable(
                        "kernel was not compiled for Vulkan".into(),
                    ));
                };
                let vk_bufs: Vec<&crate::backend::vulkan::VulkanBuffer> = buffers
                    .iter()
                    .map(|buf| match buf {
                        BackendBuffer::Vulkan(vb) => Ok(vb),
                        #[cfg(feature = "cuda")]
                        _ => Err(GpuError::BackendUnavailable(
                            "buffer/backend mismatch: expected Vulkan buffer".into(),
                        )),
                    })
                    .collect::<Result<Vec<_>>>()?;
                b.dispatch_pipeline(vk_kernel, &vk_bufs, workgroups, push_constants)
            }
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => {
                // Mirror the Vulkan arm: this pattern is irrefutable when
                // CUDA is the only compiled-in backend.
                #[allow(irrefutable_let_patterns)]
                let BackendKernel::Cuda(cuda_kernel) = &kernel.inner else {
                    return Err(GpuError::BackendUnavailable(
                        "kernel was not compiled for CUDA".into(),
                    ));
                };
                let cuda_bufs: Vec<&crate::backend::cuda::CudaBuffer> = buffers
                    .iter()
                    .map(|buf| match buf {
                        BackendBuffer::Cuda(cb) => Ok(cb),
                        #[cfg(feature = "vulkan")]
                        _ => Err(GpuError::BackendUnavailable(
                            "buffer/backend mismatch: expected CUDA buffer".into(),
                        )),
                    })
                    .collect::<Result<Vec<_>>>()?;
                b.dispatch_pipeline(cuda_kernel, &cuda_bufs, workgroups, push_constants)
            }
        }
    }
}

#[cfg(feature = "cuda")]
impl Device {
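    /// Compiles a CUDA source string into a [`Kernel`], bypassing the WGSL
    /// path entirely. Because no module reflection happens here, the caller
    /// must state the binding count and block (workgroup) size. Only valid
    /// on a CUDA device.
    ///
    /// A minimal sketch (`CUDA_SRC` and the `"scale"` entry point are
    /// hypothetical):
    ///
    /// ```ignore
    /// let kernel = device.compile_cuda(CUDA_SRC, "scale", 1, [256, 1, 1])?;
    /// device.run(&kernel, &[&data], 1024)?;
    /// ```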
    pub fn compile_cuda(
        &self,
        source: &str,
        entry_point: &str,
        binding_count: usize,
        workgroup_size: [u32; 3],
    ) -> Result<Kernel> {
        match &self.inner {
            DeviceInner::Cuda(b) => {
                let block_dim = (workgroup_size[0], workgroup_size[1], workgroup_size[2]);
                let cuda_kernel = b.compile_cuda(source, entry_point, block_dim)?;
                Ok(Kernel {
                    inner: BackendKernel::Cuda(cuda_kernel),
                    binding_count,
                    workgroup_size,
                    entry_point: entry_point.to_string(),
                })
            }
            #[cfg(feature = "vulkan")]
            _ => Err(GpuError::BackendUnavailable(
                "compile_cuda requires CUDA backend".into(),
            )),
        }
    }

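    /// Multiplies two matrices with cuBLAS, assuming the usual GEMM
    /// convention for `m`, `n`, and `k`: `a` is `m x k`, `b` is `k x n`, and
    /// the product lands in `c` (`m x n`). Only valid on a CUDA device, and
    /// all three buffers must have been created by it.
    ///
    /// A minimal sketch (hypothetical square matrices):
    ///
    /// ```ignore
    /// let a = device.upload(&vec![1.0f32; 64 * 64])?;
    /// let b = device.upload(&vec![1.0f32; 64 * 64])?;
    /// let mut c = device.alloc::<f32>(64 * 64)?;
    /// device.cublas_matmul(&a, &b, &mut c, 64, 64, 64)?;
    /// ```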
    #[allow(clippy::many_single_char_names)]
    // The let-else patterns below are irrefutable when CUDA is the only
    // compiled-in backend.
    #[allow(irrefutable_let_patterns)]
    pub fn cublas_matmul(
        &self,
        a: &Buffer<f32>,
        b: &Buffer<f32>,
        c: &mut Buffer<f32>,
        m: u32,
        n: u32,
        k: u32,
    ) -> Result<()> {
        match &self.inner {
            DeviceInner::Cuda(backend) => {
                let BackendBuffer::Cuda(a_buf) = &a.inner else {
                    return Err(GpuError::BackendUnavailable(
                        "buffer not from CUDA backend".into(),
                    ));
                };
                let BackendBuffer::Cuda(b_buf) = &b.inner else {
                    return Err(GpuError::BackendUnavailable(
                        "buffer not from CUDA backend".into(),
                    ));
                };
                let BackendBuffer::Cuda(c_buf) = &mut c.inner else {
                    return Err(GpuError::BackendUnavailable(
                        "buffer not from CUDA backend".into(),
                    ));
                };
                backend.cublas_matmul(a_buf, b_buf, c_buf, m, n, k)
            }
            #[cfg(feature = "vulkan")]
            _ => Err(GpuError::BackendUnavailable(
                "cublas_matmul requires CUDA backend".into(),
            )),
        }
    }
}

impl std::fmt::Debug for Device {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Device")
            .field("name", &self.name())
            .field("memory_mb", &(self.memory() / (1024 * 1024)))
            .finish()
    }
}