1use crate::{
2 get_error,
3 gpu::{Device, ShaderFormat, ShaderStage, WeakDevice},
4 Error,
5};
6use std::{ffi::CStr, sync::Arc};
7use sys::gpu::{SDL_GPUShader, SDL_GPUShaderCreateInfo};
8
9struct ShaderContainer {
11 raw: *mut SDL_GPUShader,
12 device: WeakDevice,
13}
14impl Drop for ShaderContainer {
15 fn drop(&mut self) {
16 if let Some(device) = self.device.upgrade() {
17 unsafe { sys::gpu::SDL_ReleaseGPUShader(device.raw(), self.raw) }
18 }
19 }
20}
21
22#[derive(Clone)]
23pub struct Shader {
24 inner: Arc<ShaderContainer>,
25}
26impl Shader {
27 #[inline]
28 pub fn raw(&self) -> *mut SDL_GPUShader {
29 self.inner.raw
30 }
31}
32
33pub struct ShaderBuilder<'a> {
34 device: &'a Device,
35 inner: SDL_GPUShaderCreateInfo,
36}
37impl<'a> ShaderBuilder<'a> {
38 pub(super) fn new(device: &'a Device) -> Self {
39 Self {
40 device,
41 inner: Default::default(),
42 }
43 }
44
45 pub fn with_samplers(mut self, value: u32) -> Self {
46 self.inner.num_samplers = value;
47 self
48 }
49
50 pub fn with_storage_buffers(mut self, value: u32) -> Self {
51 self.inner.num_storage_buffers = value;
52 self
53 }
54
55 pub fn with_storage_textures(mut self, value: u32) -> Self {
56 self.inner.num_storage_textures = value;
57 self
58 }
59
60 pub fn with_uniform_buffers(mut self, value: u32) -> Self {
61 self.inner.num_uniform_buffers = value;
62 self
63 }
64
65 pub fn with_code(mut self, fmt: ShaderFormat, code: &'a [u8], stage: ShaderStage) -> Self {
66 self.inner.format = fmt as u32;
67 self.inner.code = code.as_ptr();
68 self.inner.code_size = code.len() as usize;
69 self.inner.stage = unsafe { std::mem::transmute(stage as u32) };
70 self
71 }
72 pub fn with_entrypoint(mut self, entry_point: &'a CStr) -> Self {
73 self.inner.entrypoint = entry_point.as_ptr();
74 self
75 }
76 pub fn build(self) -> Result<Shader, Error> {
77 let raw_shader = unsafe { sys::gpu::SDL_CreateGPUShader(self.device.raw(), &self.inner) };
78 if !raw_shader.is_null() {
79 Ok(Shader {
80 inner: Arc::new(ShaderContainer {
81 raw: raw_shader,
82 device: self.device.weak(),
83 }),
84 })
85 } else {
86 Err(get_error())
87 }
88 }
89}