Skip to main content

optic_render/handles/
shader.rs

1use optic_core::OpticError;
2use optic_core::{OpticErrorKind, OpticResult};
3use cgmath::{Matrix, Matrix2, Matrix3, Matrix4, Vector2, Vector3, Vector4};
4use gl::types::GLint;
5use std::ffi::CString;
6use std::ptr;
7
8use crate::handles::{StorageBuffer, Texture2D};
9use crate::GL;
10
/// A texture or storage-buffer binding slot (0–15).
///
/// Provides named variants for readability at call sites. The variants are
/// declared in ascending order without explicit discriminants, so each
/// variant's discriminant equals its slot index.
///
/// `Slot` is `Copy`, so the same value can be handed to several by-value
/// APIs without cloning, and it can be compared or used as a map key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Slot {
    S0, S1, S2, S3, S4, S5, S6, S7,
    S8, S9, S10, S11, S12, S13, S14, S15,
}

impl Slot {
    /// Returns the integer index of this slot (0–15).
    pub fn as_index(&self) -> usize {
        // Discriminants are implicit and start at 0 in declaration order,
        // so the cast yields exactly the slot number.
        *self as usize
    }

    /// Returns the total number of available slots (16).
    pub fn total_slots() -> usize { 16 }
}
33
/// Work-group counts for a compute shader dispatch.
///
/// [`Shader::compute`] reads these three values and forwards them to
/// `glDispatchCompute`.
#[derive(Clone, Debug)]
pub struct Workers {
    pub group_x: u32,
    pub group_y: u32,
    pub group_z: u32,
}

impl Workers {
    /// Builds a `Workers` whose three group counts are all 0 (no dispatch).
    pub fn empty() -> Self {
        Workers { group_x: 0, group_y: 0, group_z: 0 }
    }

    /// Builds a `Workers` whose three group counts are all 1.
    pub fn one() -> Self {
        Workers { group_x: 1, group_y: 1, group_z: 1 }
    }

    /// Overwrites all three work-group counts in one call.
    pub fn set_groups(&mut self, x: u32, y: u32, z: u32) {
        self.group_x = x;
        self.group_y = y;
        self.group_z = z;
    }

    /// Returns the work-group counts as the tuple `(x, y, z)`.
    pub fn groups(&self) -> (u32, u32, u32) {
        (self.group_x(), self.group_y(), self.group_z())
    }

    /// Returns the X work-group count.
    pub fn group_x(&self) -> u32 {
        self.group_x
    }

    /// Returns the Y work-group count.
    pub fn group_y(&self) -> u32 {
        self.group_y
    }

    /// Returns the Z work-group count.
    pub fn group_z(&self) -> u32 {
        self.group_z
    }

    /// Replaces the X work-group count.
    pub fn set_group_x(&mut self, x: u32) {
        self.group_x = x;
    }

    /// Replaces the Y work-group count.
    pub fn set_group_y(&mut self, y: u32) {
        self.group_y = y;
    }

    /// Replaces the Z work-group count.
    pub fn set_group_z(&mut self, z: u32) {
        self.group_z = z;
    }
}
68
/// A handle to an OpenGL shader program.
///
/// Supports both pipeline (vertex+fragment) and compute shaders.
/// Manages texture and storage-buffer bindings for automatic binding
/// during rendering or compute dispatch.
///
/// The handle is a plain value: cloning it copies the program ID and the
/// slot tables, it does not create a second GL program.
///
/// # Uniform setters
///
/// The `set_*` family of methods sets uniform variables by name. They will
/// panic if the uniform does not exist (use [`uniform_location`](Shader::uniform_location)
/// for optional lookups).
#[derive(Clone, Debug)]
pub struct Shader {
    /// Work-group counts passed to `glDispatchCompute` by [`Shader::compute`].
    pub workers: Workers,
    /// The GL program object ID this handle wraps.
    pub id: u32,
    /// Selects how attached textures are bound: as image units when `true`,
    /// as regular texture units when `false`.
    pub is_compute: bool,
    /// Texture IDs indexed by slot; `None` marks an empty slot.
    /// [`Shader::new`] creates one entry per [`Slot`].
    pub tex_ids: Vec<Option<u32>>,
    /// Storage-buffer IDs indexed by slot; `None` marks an empty slot.
    /// [`Shader::new`] creates one entry per [`Slot`].
    pub sbo_ids: Vec<Option<u32>>,
}
88
89impl Shader {
90    /// Wraps an existing GL program ID.
91    ///
92    /// `is_compute` controls whether textures are bound via `BindImageTexture`
93    /// (compute) or `glActiveTexture` / `glBindTexture` (render).
94    pub fn new(id: u32, is_compute: bool) -> Self {
95        Self {
96            workers: Workers::empty(),
97            id,
98            is_compute,
99            tex_ids: vec![None; Slot::total_slots()],
100            sbo_ids: vec![None; Slot::total_slots()],
101        }
102    }
103
104    /// Attaches a texture to the first available (empty) texture slot.
105    pub fn attach_tex(&mut self, tex: &Texture2D) {
106        for slot in self.tex_ids.iter_mut() {
107            if slot.is_none() {
108                *slot = Some(tex.id);
109                break;
110            }
111        }
112    }
113
114    /// Attaches a storage buffer to the first available (empty) SSBO slot.
115    pub fn attach_sbo(&mut self, sbo: &StorageBuffer) {
116        for slot in self.sbo_ids.iter_mut() {
117            if slot.is_none() {
118                *slot = Some(sbo.id);
119                break;
120            }
121        }
122    }
123
124    /// Binds a texture to a specific slot.
125    pub fn set_tex_at_slot(&mut self, tex: &Texture2D, slot: Slot) {
126        self.tex_ids[slot.as_index()] = Some(tex.id);
127    }
128
129    /// Binds a storage buffer to a specific slot.
130    pub fn set_sbo_at_slot(&mut self, sbo: &StorageBuffer, slot: Slot) {
131        self.sbo_ids[slot.as_index()] = Some(sbo.id);
132    }
133
134    /// Deletes the underlying GL program.
135    pub fn delete(self) { delete_program(self.id); }
136
137    /// Binds this shader program (`glUseProgram`).
138    pub fn bind(&self) { unsafe { gl::UseProgram(self.id); } }
139    /// Unbinds the current shader (binds program 0).
140    pub fn unbind(&self) { unsafe { gl::UseProgram(0); } }
141
142    /// Dispatches compute with the currently bound textures and SSBOs.
143    ///
144    /// Calls `glDispatchCompute(workers)` followed by a memory barrier for
145    /// shader image access and shader storage access.
146    pub fn compute(&self) {
147        self.bind();
148        self.bind_textures();
149        self.bind_storages();
150        let (x, y, z) = self.workers.groups();
151        unsafe {
152            gl::DispatchCompute(x, y, z);
153            gl::MemoryBarrier(
154                gl::SHADER_IMAGE_ACCESS_BARRIER_BIT | gl::SHADER_STORAGE_BARRIER_BIT,
155            );
156        }
157    }
158
159    /// Looks up a uniform location by name, returning `None` if not found.
160    pub fn uniform_location(&self, name: &str) -> Option<u32> {
161        unsafe {
162            let c_name = CString::new(name).unwrap();
163            let loc = gl::GetUniformLocation(self.id, c_name.as_ptr());
164            if loc == -1 { None } else { Some(loc as u32) }
165        }
166    }
167
168    /// Looks up a uniform location — panics if not found.
169    fn uni_loc(&self, name: &str) -> GLint {
170        unsafe {
171            let c_name = CString::new(name).unwrap();
172            let loc = gl::GetUniformLocation(self.id, c_name.as_ptr());
173            if loc == -1 {
174                panic!("uniform '{name}' does not exist in shader {}", self.id);
175            }
176            loc
177        }
178    }
179
180    /// Returns all (slot, tex_id) pairs for currently bound textures.
181    pub fn texture_binds(&self) -> Vec<(u32, u32)> {
182        self.tex_ids.iter().enumerate()
183            .filter_map(|(slot, id)| id.map(|tid| (slot as u32, tid)))
184            .collect()
185    }
186    /// Returns all (slot, sbo_id) pairs for currently bound storage buffers.
187    pub fn storage_binds(&self) -> Vec<(u32, u32)> {
188        self.sbo_ids.iter().enumerate()
189            .filter_map(|(slot, id)| id.map(|sid| (slot as u32, sid)))
190            .collect()
191    }
192    /// Binds all attached textures (image uniforms for compute, sampler2D for pipeline).
193    pub fn bind_textures(&self) {
194        for (slot, tex_id) in self.tex_ids.iter().enumerate() {
195            if let Some(id) = tex_id {
196                if self.is_compute {
197                    unsafe {
198                        gl::BindImageTexture(
199                            slot as u32, *id, 0, gl::FALSE, 0, gl::READ_WRITE, gl::RGBA8,
200                        );
201                    }
202                } else {
203                    GL::bind_texture_at(*id, slot as u32);
204                }
205            }
206        }
207    }
208
209    /// Binds all attached storage buffers to their slots.
210    pub fn bind_storages(&self) {
211        for (slot, sbo_id) in self.sbo_ids.iter().enumerate() {
212            if let Some(id) = sbo_id {
213                unsafe {
214                    gl::BindBufferBase(gl::SHADER_STORAGE_BUFFER, slot as u32, *id);
215                }
216            }
217        }
218    }
219
220    /// Sets an `int` uniform.
221    pub fn set_i32(&self, name: &str, v: i32) {
222        unsafe { gl::Uniform1i(self.uni_loc(name), v); }
223    }
224    /// Sets a `uint` uniform.
225    pub fn set_u32(&self, name: &str, v: u32) {
226        unsafe { gl::Uniform1ui(self.uni_loc(name), v); }
227    }
228    /// Sets a `float` uniform.
229    pub fn set_f32(&self, name: &str, v: f32) {
230        unsafe { gl::Uniform1f(self.uni_loc(name), v); }
231    }
232    /// Sets a `vec2` uniform.
233    pub fn set_vec2_f32(&self, name: &str, v: Vector2<f32>) {
234        unsafe { gl::Uniform2f(self.uni_loc(name), v.x, v.y); }
235    }
236    /// Sets a `vec3` uniform.
237    pub fn set_vec3_f32(&self, name: &str, v: Vector3<f32>) {
238        unsafe { gl::Uniform3f(self.uni_loc(name), v.x, v.y, v.z); }
239    }
240    /// Sets a `vec4` uniform.
241    pub fn set_vec4_f32(&self, name: &str, v: Vector4<f32>) {
242        unsafe { gl::Uniform4f(self.uni_loc(name), v.x, v.y, v.z, v.w); }
243    }
244    /// Sets an `ivec2` uniform.
245    pub fn set_vec2_i32(&self, name: &str, v: Vector2<i32>) {
246        unsafe { gl::Uniform2i(self.uni_loc(name), v.x, v.y); }
247    }
248    /// Sets an `ivec3` uniform.
249    pub fn set_vec3_i32(&self, name: &str, v: Vector3<i32>) {
250        unsafe { gl::Uniform3i(self.uni_loc(name), v.x, v.y, v.z); }
251    }
252    /// Sets an `ivec4` uniform.
253    pub fn set_vec4_i32(&self, name: &str, v: Vector4<i32>) {
254        unsafe { gl::Uniform4i(self.uni_loc(name), v.x, v.y, v.z, v.w); }
255    }
256    /// Sets a `uvec2` uniform.
257    pub fn set_vec2_u32(&self, name: &str, v: Vector2<u32>) {
258        unsafe { gl::Uniform2ui(self.uni_loc(name), v.x, v.y); }
259    }
260    /// Sets a `uvec3` uniform.
261    pub fn set_vec3_u32(&self, name: &str, v: Vector3<u32>) {
262        unsafe { gl::Uniform3ui(self.uni_loc(name), v.x, v.y, v.z); }
263    }
264    /// Sets a `uvec4` uniform.
265    pub fn set_vec4_u32(&self, name: &str, v: Vector4<u32>) {
266        unsafe { gl::Uniform4ui(self.uni_loc(name), v.x, v.y, v.z, v.w); }
267    }
268    /// Sets a `mat2` uniform.
269    pub fn set_m2_f32(&self, name: &str, m: Matrix2<f32>) {
270        unsafe { gl::UniformMatrix2fv(self.uni_loc(name), 1, gl::FALSE, m.as_ptr()); }
271    }
272    /// Sets a `mat3` uniform.
273    pub fn set_m3_f32(&self, name: &str, m: Matrix3<f32>) {
274        unsafe { gl::UniformMatrix3fv(self.uni_loc(name), 1, gl::FALSE, m.as_ptr()); }
275    }
276    /// Sets a `mat4` uniform.
277    pub fn set_m4_f32(&self, name: &str, m: Matrix4<f32>) {
278        unsafe { gl::UniformMatrix4fv(self.uni_loc(name), 1, gl::FALSE, m.as_ptr()); }
279    }
280}
281
282/// Compiles a single GLSL shader stage.
283///
284/// `shader_type` should be one of `gl::VERTEX_SHADER`, `gl::FRAGMENT_SHADER`,
285/// or `gl::COMPUTE_SHADER`. Returns the shader object ID on success.
286pub fn compile_shader(src: &str, shader_type: gl::types::GLenum) -> OpticResult<u32> {
287    let c_src = CString::new(src)
288        .map_err(|e| OpticError::new(OpticErrorKind::Shader, &format!("null byte in shader source: {e}")))?;
289
290    unsafe {
291        let id = gl::CreateShader(shader_type);
292        gl::ShaderSource(id, 1, &c_src.as_ptr(), ptr::null());
293        gl::CompileShader(id);
294
295        let mut success = gl::FALSE as GLint;
296        gl::GetShaderiv(id, gl::COMPILE_STATUS, &mut success);
297        if success != gl::TRUE as GLint {
298            let mut log_len = 0;
299            gl::GetShaderiv(id, gl::INFO_LOG_LENGTH, &mut log_len);
300            let mut log = vec![0u8; log_len.max(1) as usize - 1];
301            gl::GetShaderInfoLog(
302                id, log_len, ptr::null_mut(),
303                log.as_mut_ptr() as *mut gl::types::GLchar,
304            );
305            let msg = String::from_utf8_lossy(&log).to_string();
306            gl::DeleteShader(id);
307            return Err(OpticError::new(OpticErrorKind::Shader, &msg));
308        }
309        Ok(id)
310    }
311}
312
313/// Links a vertex + fragment shader pair into a GL program.
314///
315/// Both shader stages are compiled and linked. Returns the program ID on success.
316pub fn link_program(vert: &str, frag: &str) -> OpticResult<u32> {
317    let v_id = compile_shader(vert, gl::VERTEX_SHADER)?;
318    let f_id = compile_shader(frag, gl::FRAGMENT_SHADER)?;
319
320    unsafe {
321        let program = gl::CreateProgram();
322        gl::AttachShader(program, v_id);
323        gl::AttachShader(program, f_id);
324        gl::LinkProgram(program);
325
326        let mut success = gl::FALSE as GLint;
327        gl::GetProgramiv(program, gl::LINK_STATUS, &mut success);
328        gl::DeleteShader(v_id);
329        gl::DeleteShader(f_id);
330
331        if success != gl::TRUE as GLint {
332            let mut log_len = 0;
333            gl::GetProgramiv(program, gl::INFO_LOG_LENGTH, &mut log_len);
334            let mut log = vec![0u8; log_len.max(1) as usize - 1];
335            gl::GetProgramInfoLog(
336                program, log_len, ptr::null_mut(),
337                log.as_mut_ptr() as *mut gl::types::GLchar,
338            );
339            let msg = String::from_utf8_lossy(&log).to_string();
340            gl::DeleteProgram(program);
341            return Err(OpticError::new(OpticErrorKind::Shader, &msg));
342        }
343        Ok(program)
344    }
345}
346
347/// Links a compute shader source into a GL program.
348///
349/// Returns the program ID on success.
350pub fn link_compute_program(src: &str) -> OpticResult<u32> {
351    let c_id = compile_shader(src, gl::COMPUTE_SHADER)?;
352
353    unsafe {
354        let program = gl::CreateProgram();
355        gl::AttachShader(program, c_id);
356        gl::LinkProgram(program);
357
358        let mut success = gl::FALSE as GLint;
359        gl::GetProgramiv(program, gl::LINK_STATUS, &mut success);
360        gl::DeleteShader(c_id);
361
362        if success != gl::TRUE as GLint {
363            let mut log_len = 0;
364            gl::GetProgramiv(program, gl::INFO_LOG_LENGTH, &mut log_len);
365            let mut log = vec![0u8; log_len.max(1) as usize - 1];
366            gl::GetProgramInfoLog(
367                program, log_len, ptr::null_mut(),
368                log.as_mut_ptr() as *mut gl::types::GLchar,
369            );
370            let msg = String::from_utf8_lossy(&log).to_string();
371            gl::DeleteProgram(program);
372            return Err(OpticError::new(OpticErrorKind::Shader, &msg));
373        }
374        Ok(program)
375    }
376}
377
378/// Deletes a GL program object.
379pub fn delete_program(id: u32) {
380    unsafe { gl::DeleteProgram(id); }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// First, middle, and last slots map to their numeric indices.
    #[test]
    fn slot_as_index() {
        for (slot, expected) in [(Slot::S0, 0), (Slot::S7, 7), (Slot::S15, 15)] {
            assert_eq!(slot.as_index(), expected);
        }
    }

    /// The slot count is fixed at 16.
    #[test]
    fn slot_total_slots() {
        let total = Slot::total_slots();
        assert_eq!(total, 16);
    }

    /// `Workers::empty` starts with all three group counts at zero.
    #[test]
    fn workers_empty() {
        let (x, y, z) = Workers::empty().groups();
        assert_eq!((x, y, z), (0, 0, 0));
    }

    /// `set_groups` overwrites all three counts at once.
    #[test]
    fn workers_set_groups() {
        let mut workers = Workers::empty();
        workers.set_groups(10, 1, 1);
        assert_eq!(workers.groups(), (10, 1, 1));
    }

    /// The per-axis setters each touch only their own axis.
    #[test]
    fn workers_set_individual() {
        let mut workers = Workers::empty();
        workers.set_group_x(8);
        workers.set_group_y(4);
        workers.set_group_z(2);
        assert_eq!(workers.groups(), (8, 4, 2));
    }

    /// A render shader handle keeps its ID and gets 16 empty slots of each kind.
    #[test]
    fn shader_new() {
        let shader = Shader::new(42, false);
        assert_eq!(shader.id, 42);
        assert!(!shader.is_compute);
        assert_eq!(shader.tex_ids.len(), 16);
        assert_eq!(shader.sbo_ids.len(), 16);
    }

    /// The compute flag is stored as given.
    #[test]
    fn shader_new_compute() {
        let shader = Shader::new(99, true);
        assert!(shader.is_compute);
    }

    /// Work groups can be configured through the shader's `workers` field.
    #[test]
    fn shader_workers_association() {
        let mut shader = Shader::new(1, true);
        shader.workers.set_groups(16, 1, 1);
        assert_eq!(shader.workers.groups(), (16, 1, 1));
    }
}