Skip to main content

oximedia_gpu/
shader_params.rs

1//! Shader parameter management — param types, individual params, and uniform blocks.
2
3#![allow(dead_code)]
4#![allow(clippy::cast_precision_loss)]
5
/// The data type of a shader parameter.
///
/// Every variant is a fieldless tag, so the type is cheap to copy and can be
/// used directly as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParamType {
    /// 32-bit float scalar.
    Float,
    /// 32-bit signed integer scalar.
    Int,
    /// Two-component float vector.
    Vec2,
    /// Three-component float vector.
    Vec3,
    /// Four-component float vector.
    Vec4,
    /// 4×4 float matrix.
    Matrix4x4,
    /// Texture / sampler handle.
    Texture,
}

impl ParamType {
    /// Returns the byte size of this parameter type in GPU memory.
    ///
    /// The value is the size of the data itself; no padding is added (for
    /// example, [`ParamType::Vec3`] reports 12 bytes, not 16).
    ///
    /// This is a `const fn`, so the result can be used in constant contexts.
    // The receiver stays `&self` so that existing callers which pass a
    // reference explicitly keep compiling now that the type is `Copy`.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    #[must_use]
    pub const fn byte_size(&self) -> usize {
        match self {
            // Both scalar types are 32 bits wide.
            Self::Float | Self::Int => 4,
            Self::Vec2 => 8,
            Self::Vec3 => 12,
            Self::Vec4 => 16,
            Self::Matrix4x4 => 64,
            // Texture handles are represented as 8-byte opaque handles.
            Self::Texture => 8,
        }
    }
}
41
42/// A single named parameter within a uniform block.
43#[derive(Debug, Clone)]
44pub struct ShaderParam {
45    /// The GLSL/WGSL name of the parameter.
46    pub name: String,
47    /// The data type.
48    pub param_type: ParamType,
49    /// Byte offset within the owning uniform block.
50    pub offset: u32,
51}
52
53impl ShaderParam {
54    /// Creates a new `ShaderParam`.
55    #[must_use]
56    pub fn new(name: impl Into<String>, param_type: ParamType, offset: u32) -> Self {
57        Self {
58            name: name.into(),
59            param_type,
60            offset,
61        }
62    }
63
64    /// The byte offset *past* the end of this parameter (exclusive end offset).
65    #[must_use]
66    pub fn end_offset(&self) -> u32 {
67        self.offset + self.param_type.byte_size() as u32
68    }
69}
70
71/// A named uniform block containing multiple [`ShaderParam`] entries.
72#[derive(Debug)]
73pub struct UniformBlock {
74    /// Ordered list of parameters in this block.
75    pub params: Vec<ShaderParam>,
76    /// Name of the uniform block (e.g. `"Globals"`).
77    pub name: String,
78}
79
80impl UniformBlock {
81    /// Creates a new empty `UniformBlock` with the given name.
82    #[must_use]
83    pub fn new(name: impl Into<String>) -> Self {
84        Self {
85            params: Vec::new(),
86            name: name.into(),
87        }
88    }
89
90    /// Appends a parameter to the block.
91    pub fn add_param(&mut self, param: ShaderParam) {
92        self.params.push(param);
93    }
94
95    /// Finds the first parameter with the given name.
96    #[must_use]
97    pub fn find(&self, name: &str) -> Option<&ShaderParam> {
98        self.params.iter().find(|p| p.name == name)
99    }
100
101    /// Returns the total size in bytes of all parameters in this block.
102    ///
103    /// This is the sum of each parameter's byte size (not the end offset of
104    /// the last parameter, since params may not be tightly packed).
105    #[must_use]
106    pub fn total_size_bytes(&self) -> u32 {
107        self.params
108            .iter()
109            .map(|p| p.param_type.byte_size() as u32)
110            .sum()
111    }
112
113    /// Returns the number of parameters in this block.
114    #[must_use]
115    pub fn param_count(&self) -> usize {
116        self.params.len()
117    }
118
119    /// Returns `true` if the total size is a multiple of 16 (std140 alignment).
120    #[must_use]
121    pub fn is_aligned(&self) -> bool {
122        self.total_size_bytes() % 16 == 0
123    }
124}
125
126// ---------------------------------------------------------------------------
127// Unit tests
128// ---------------------------------------------------------------------------
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a block called `name` and appends `params` in the given order.
    fn block_of(name: &str, params: Vec<ShaderParam>) -> UniformBlock {
        let mut block = UniformBlock::new(name);
        for param in params {
            block.add_param(param);
        }
        block
    }

    // --- ParamType::byte_size -------------------------------------------

    #[test]
    fn test_float_byte_size() {
        let size = ParamType::Float.byte_size();
        assert_eq!(size, 4);
    }

    #[test]
    fn test_int_byte_size() {
        let size = ParamType::Int.byte_size();
        assert_eq!(size, 4);
    }

    #[test]
    fn test_vec2_byte_size() {
        let size = ParamType::Vec2.byte_size();
        assert_eq!(size, 8);
    }

    #[test]
    fn test_vec3_byte_size() {
        let size = ParamType::Vec3.byte_size();
        assert_eq!(size, 12);
    }

    #[test]
    fn test_vec4_byte_size() {
        let size = ParamType::Vec4.byte_size();
        assert_eq!(size, 16);
    }

    #[test]
    fn test_matrix4x4_byte_size() {
        let size = ParamType::Matrix4x4.byte_size();
        assert_eq!(size, 64);
    }

    #[test]
    fn test_texture_byte_size() {
        let size = ParamType::Texture.byte_size();
        assert_eq!(size, 8);
    }

    // --- ShaderParam::end_offset ----------------------------------------

    #[test]
    fn test_shader_param_end_offset() {
        // Float(4) starting at 0 ends at 4.
        let brightness = ShaderParam::new("brightness", ParamType::Float, 0);
        assert_eq!(brightness.end_offset(), 4);
    }

    #[test]
    fn test_shader_param_end_offset_vec4() {
        // Vec4(16) starting at 16 ends at 32.
        let color = ShaderParam::new("color", ParamType::Vec4, 16);
        assert_eq!(color.end_offset(), 32);
    }

    #[test]
    fn test_shader_param_end_offset_matrix() {
        // Matrix4x4(64) starting at 64 ends at 128.
        let mvp = ShaderParam::new("mvp", ParamType::Matrix4x4, 64);
        assert_eq!(mvp.end_offset(), 128);
    }

    // --- UniformBlock ---------------------------------------------------

    #[test]
    fn test_uniform_block_add_and_count() {
        let globals = block_of(
            "Globals",
            vec![
                ShaderParam::new("time", ParamType::Float, 0),
                ShaderParam::new("resolution", ParamType::Vec2, 4),
            ],
        );
        assert_eq!(globals.param_count(), 2);
    }

    #[test]
    fn test_uniform_block_find_existing() {
        let params = block_of("Params", vec![ShaderParam::new("gamma", ParamType::Float, 0)]);
        let lookup = params.find("gamma");
        assert!(lookup.is_some());
        let gamma = lookup.expect("parameter should be found");
        assert_eq!(gamma.offset, 0);
    }

    #[test]
    fn test_uniform_block_find_missing() {
        let params = block_of("Params", Vec::new());
        let lookup = params.find("nonexistent");
        assert!(lookup.is_none());
    }

    #[test]
    fn test_uniform_block_total_size() {
        // Float(4) + Vec4(16) = 20
        let mix = block_of(
            "Mix",
            vec![
                ShaderParam::new("alpha", ParamType::Float, 0),
                ShaderParam::new("tint", ParamType::Vec4, 4),
            ],
        );
        assert_eq!(mix.total_size_bytes(), 20);
    }

    #[test]
    fn test_uniform_block_is_aligned_true() {
        // A lone Vec4 totals 16 bytes, which is a multiple of 16.
        let aligned = block_of("Aligned", vec![ShaderParam::new("v", ParamType::Vec4, 0)]);
        assert!(aligned.is_aligned());
    }

    #[test]
    fn test_uniform_block_is_aligned_false() {
        // A lone Float totals 4 bytes, which is not a multiple of 16.
        let unaligned = block_of("Unaligned", vec![ShaderParam::new("x", ParamType::Float, 0)]);
        assert!(!unaligned.is_aligned());
    }

    #[test]
    fn test_uniform_block_matrix_is_aligned() {
        // A lone Matrix4x4 totals 64 bytes, which is a multiple of 16.
        let mvp = block_of("MVP", vec![ShaderParam::new("mvp", ParamType::Matrix4x4, 0)]);
        assert!(mvp.is_aligned());
    }
}
242}