Skip to main content

astrelis_render/
material.rs

//! Material system for high-level shader parameter management.
//!
//! Provides a declarative API for managing shader parameters, textures, and pipeline state.
//!
//! # Example
//!
//! ```ignore
//! use astrelis_render::*;
//! use glam::{Vec2, Vec3, Mat4};
//!
//! // `context` is an `Arc<GraphicsContext>`.
//! let mut material = Material::new(shader, context.clone());
//!
//! // Set parameters
//! material.set_parameter("color", MaterialParameter::Color(Color::RED));
//! material.set_parameter("time", MaterialParameter::Float(1.5));
//! material.set_parameter("view_proj", MaterialParameter::Matrix4(view_proj_matrix));
//!
//! // Set textures (`albedo` is a `MaterialTexture`: texture, view, optional sampler)
//! material.set_texture("albedo_texture", albedo);
//!
//! // Apply material to render pass at bind group index 0
//! material.bind(&mut pass, 0);
//! ```
24
25use astrelis_core::profiling::profile_function;
26
27use crate::{Color, GraphicsContext};
28use ahash::HashMap;
29use glam::{Mat4, Vec2, Vec3, Vec4};
30use std::sync::Arc;
31
32/// A material parameter value that can be bound to a shader.
33#[derive(Debug, Clone)]
34pub enum MaterialParameter {
35    /// Single float value
36    Float(f32),
37    /// 2D vector
38    Vec2(Vec2),
39    /// 3D vector
40    Vec3(Vec3),
41    /// 4D vector
42    Vec4(Vec4),
43    /// RGBA color
44    Color(Color),
45    /// 4x4 matrix
46    Matrix4(Mat4),
47    /// Array of floats
48    FloatArray(Vec<f32>),
49    /// Array of Vec2
50    Vec2Array(Vec<Vec2>),
51    /// Array of Vec3
52    Vec3Array(Vec<Vec3>),
53    /// Array of Vec4
54    Vec4Array(Vec<Vec4>),
55}
56
57impl MaterialParameter {
58    /// Convert parameter to bytes for GPU upload.
59    pub fn as_bytes(&self) -> Vec<u8> {
60        match self {
61            MaterialParameter::Float(v) => bytemuck::bytes_of(v).to_vec(),
62            MaterialParameter::Vec2(v) => bytemuck::bytes_of(v).to_vec(),
63            MaterialParameter::Vec3(v) => {
64                // Pad Vec3 to 16 bytes for alignment
65                let mut bytes = Vec::with_capacity(16);
66                bytes.extend_from_slice(bytemuck::bytes_of(v));
67                bytes.extend_from_slice(&[0u8; 4]); // padding
68                bytes
69            }
70            MaterialParameter::Vec4(v) => bytemuck::bytes_of(v).to_vec(),
71            MaterialParameter::Color(c) => bytemuck::bytes_of(c).to_vec(),
72            MaterialParameter::Matrix4(m) => bytemuck::bytes_of(m).to_vec(),
73            MaterialParameter::FloatArray(arr) => bytemuck::cast_slice(arr).to_vec(),
74            MaterialParameter::Vec2Array(arr) => bytemuck::cast_slice(arr).to_vec(),
75            MaterialParameter::Vec3Array(arr) => {
76                // Each Vec3 needs padding to 16 bytes
77                let mut bytes = Vec::with_capacity(arr.len() * 16);
78                for v in arr {
79                    bytes.extend_from_slice(bytemuck::bytes_of(v));
80                    bytes.extend_from_slice(&[0u8; 4]); // padding
81                }
82                bytes
83            }
84            MaterialParameter::Vec4Array(arr) => bytemuck::cast_slice(arr).to_vec(),
85        }
86    }
87
88    /// Get the size of the parameter in bytes (including padding).
89    pub fn size(&self) -> u64 {
90        match self {
91            MaterialParameter::Float(_) => 4,
92            MaterialParameter::Vec2(_) => 8,
93            MaterialParameter::Vec3(_) => 16, // Padded
94            MaterialParameter::Vec4(_) => 16,
95            MaterialParameter::Color(_) => 16,
96            MaterialParameter::Matrix4(_) => 64,
97            MaterialParameter::FloatArray(arr) => (arr.len() * 4) as u64,
98            MaterialParameter::Vec2Array(arr) => (arr.len() * 8) as u64,
99            MaterialParameter::Vec3Array(arr) => (arr.len() * 16) as u64, // Padded
100            MaterialParameter::Vec4Array(arr) => (arr.len() * 16) as u64,
101        }
102    }
103}
104
105/// Texture binding information for a material.
106#[derive(Debug, Clone)]
107pub struct MaterialTexture {
108    /// The texture to bind
109    pub texture: wgpu::Texture,
110    /// The texture view
111    pub view: wgpu::TextureView,
112    /// Optional sampler (if None, a default linear sampler will be used)
113    pub sampler: Option<wgpu::Sampler>,
114}
115
116/// Pipeline state configuration for a material.
117#[derive(Debug, Clone)]
118pub struct PipelineState {
119    /// Primitive topology (default: TriangleList)
120    pub topology: wgpu::PrimitiveTopology,
121    /// Cull mode (default: Some(Back))
122    pub cull_mode: Option<wgpu::Face>,
123    /// Front face winding (default: Ccw)
124    pub front_face: wgpu::FrontFace,
125    /// Polygon mode (default: Fill)
126    pub polygon_mode: wgpu::PolygonMode,
127    /// Depth test enabled (default: false)
128    pub depth_test: bool,
129    /// Depth write enabled (default: false)
130    pub depth_write: bool,
131    /// Blend mode (default: None - opaque)
132    pub blend: Option<wgpu::BlendState>,
133}
134
135impl Default for PipelineState {
136    fn default() -> Self {
137        Self {
138            topology: wgpu::PrimitiveTopology::TriangleList,
139            cull_mode: Some(wgpu::Face::Back),
140            front_face: wgpu::FrontFace::Ccw,
141            polygon_mode: wgpu::PolygonMode::Fill,
142            depth_test: false,
143            depth_write: false,
144            blend: None,
145        }
146    }
147}
148
149/// A material manages shader parameters, textures, and pipeline state.
150pub struct Material {
151    /// The shader module
152    shader: Arc<wgpu::ShaderModule>,
153    /// Named parameters
154    parameters: HashMap<String, MaterialParameter>,
155    /// Named textures
156    textures: HashMap<String, MaterialTexture>,
157    /// Pipeline state
158    pipeline_state: PipelineState,
159    /// Graphics context reference
160    context: Arc<GraphicsContext>,
161    /// Cached uniform buffer
162    uniform_buffer: Option<wgpu::Buffer>,
163    /// Cached bind group layout
164    bind_group_layout: Option<wgpu::BindGroupLayout>,
165    /// Cached bind group
166    bind_group: Option<wgpu::BindGroup>,
167    /// Dirty flag - set to true when parameters/textures change
168    dirty: bool,
169}
170
171impl Material {
172    /// Create a new material with a shader.
173    pub fn new(shader: Arc<wgpu::ShaderModule>, context: Arc<GraphicsContext>) -> Self {
174        Self {
175            shader,
176            parameters: HashMap::default(),
177            textures: HashMap::default(),
178            pipeline_state: PipelineState::default(),
179            context,
180            uniform_buffer: None,
181            bind_group_layout: None,
182            bind_group: None,
183            dirty: true,
184        }
185    }
186
187    /// Create a material from a shader source string.
188    pub fn from_source(
189        source: &str,
190        label: Option<&str>,
191        context: Arc<GraphicsContext>,
192    ) -> Self {
193        profile_function!();
194        let shader = context
195            .device()
196            .create_shader_module(wgpu::ShaderModuleDescriptor {
197                label,
198                source: wgpu::ShaderSource::Wgsl(source.into()),
199            });
200        Self::new(Arc::new(shader), context)
201    }
202
203    /// Set a parameter by name.
204    pub fn set_parameter(&mut self, name: impl Into<String>, value: MaterialParameter) {
205        self.parameters.insert(name.into(), value);
206        self.dirty = true;
207    }
208
209    /// Get a parameter by name.
210    pub fn get_parameter(&self, name: &str) -> Option<&MaterialParameter> {
211        self.parameters.get(name)
212    }
213
214    /// Set a texture by name.
215    pub fn set_texture(&mut self, name: impl Into<String>, texture: MaterialTexture) {
216        self.textures.insert(name.into(), texture);
217        self.dirty = true;
218    }
219
220    /// Get a texture by name.
221    pub fn get_texture(&self, name: &str) -> Option<&MaterialTexture> {
222        self.textures.get(name)
223    }
224
225    /// Set the pipeline state.
226    pub fn set_pipeline_state(&mut self, state: PipelineState) {
227        self.pipeline_state = state;
228    }
229
230    /// Get the pipeline state.
231    pub fn pipeline_state(&self) -> &PipelineState {
232        &self.pipeline_state
233    }
234
235    /// Get the shader module.
236    pub fn shader(&self) -> &wgpu::ShaderModule {
237        &self.shader
238    }
239
240    /// Update GPU resources if dirty.
241    fn update_resources(&mut self) {
242        profile_function!();
243        if !self.dirty {
244            return;
245        }
246
247        // Calculate total uniform buffer size
248        let mut uniform_size = 0u64;
249        for param in self.parameters.values() {
250            uniform_size += param.size();
251            // Add padding for alignment
252            if !uniform_size.is_multiple_of(16) {
253                uniform_size += 16 - (uniform_size % 16);
254            }
255        }
256
257        // Create or update uniform buffer
258        if uniform_size > 0 {
259            let mut uniform_data = Vec::new();
260            for param in self.parameters.values() {
261                uniform_data.extend_from_slice(&param.as_bytes());
262                // Add padding for alignment
263                let current_size = uniform_data.len() as u64;
264                if !current_size.is_multiple_of(16) {
265                    let padding = 16 - (current_size % 16);
266                    uniform_data.extend(vec![0u8; padding as usize]);
267                }
268            }
269
270            if let Some(buffer) = &self.uniform_buffer {
271                // Update existing buffer
272                self.context.queue().write_buffer(buffer, 0, &uniform_data);
273            } else {
274                // Create new buffer
275                let buffer = self
276                    .context
277                    .device()
278                    .create_buffer(&wgpu::BufferDescriptor {
279                        label: Some("Material Uniform Buffer"),
280                        size: uniform_size,
281                        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
282                        mapped_at_creation: false,
283                    });
284                self.context.queue().write_buffer(&buffer, 0, &uniform_data);
285                self.uniform_buffer = Some(buffer);
286            }
287        }
288
289        // Rebuild bind group layout and bind group
290        self.rebuild_bind_groups();
291
292        self.dirty = false;
293    }
294
295    /// Rebuild bind group layout and bind group.
296    fn rebuild_bind_groups(&mut self) {
297        let mut layout_entries = Vec::new();
298        let mut bind_entries = Vec::new();
299        let mut binding = 0u32;
300
301        // Add uniform buffer binding if present
302        if self.uniform_buffer.is_some() {
303            layout_entries.push(wgpu::BindGroupLayoutEntry {
304                binding,
305                visibility: wgpu::ShaderStages::VERTEX_FRAGMENT,
306                ty: wgpu::BindingType::Buffer {
307                    ty: wgpu::BufferBindingType::Uniform,
308                    has_dynamic_offset: false,
309                    min_binding_size: None,
310                },
311                count: None,
312            });
313
314            bind_entries.push(wgpu::BindGroupEntry {
315                binding,
316                resource: self.uniform_buffer.as_ref().unwrap().as_entire_binding(),
317            });
318
319            binding += 1;
320        }
321
322        // Add texture bindings
323        for texture in self.textures.values() {
324            // Texture binding
325            layout_entries.push(wgpu::BindGroupLayoutEntry {
326                binding,
327                visibility: wgpu::ShaderStages::FRAGMENT,
328                ty: wgpu::BindingType::Texture {
329                    sample_type: wgpu::TextureSampleType::Float { filterable: true },
330                    view_dimension: wgpu::TextureViewDimension::D2,
331                    multisampled: false,
332                },
333                count: None,
334            });
335
336            bind_entries.push(wgpu::BindGroupEntry {
337                binding,
338                resource: wgpu::BindingResource::TextureView(&texture.view),
339            });
340
341            binding += 1;
342
343            // Sampler binding
344            layout_entries.push(wgpu::BindGroupLayoutEntry {
345                binding,
346                visibility: wgpu::ShaderStages::FRAGMENT,
347                ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
348                count: None,
349            });
350
351            // Use provided sampler or create default
352            let sampler = if let Some(ref s) = texture.sampler {
353                s
354            } else {
355                // Create a default linear sampler (this should be cached in practice)
356                // For now, we'll use a temporary one
357                // TODO: Add sampler cache to Material or GraphicsContext
358                unimplemented!("Default sampler not yet implemented - please provide sampler")
359            };
360
361            bind_entries.push(wgpu::BindGroupEntry {
362                binding,
363                resource: wgpu::BindingResource::Sampler(sampler),
364            });
365
366            binding += 1;
367        }
368
369        // Create bind group layout
370        let layout = self
371            .context
372            .device()
373            .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
374                label: Some("Material Bind Group Layout"),
375                entries: &layout_entries,
376            });
377
378        // Create bind group
379        let bind_group = self
380            .context
381            .device()
382            .create_bind_group(&wgpu::BindGroupDescriptor {
383                label: Some("Material Bind Group"),
384                layout: &layout,
385                entries: &bind_entries,
386            });
387
388        self.bind_group_layout = Some(layout);
389        self.bind_group = Some(bind_group);
390    }
391
392    /// Bind this material's resources to a render pass.
393    ///
394    /// This will update GPU resources if needed and set the bind group.
395    ///
396    /// # Arguments
397    ///
398    /// * `pass` - The render pass to bind to
399    /// * `bind_group_index` - The bind group index (default is usually 0)
400    pub fn bind<'a>(&'a mut self, pass: &mut wgpu::RenderPass<'a>, bind_group_index: u32) {
401        profile_function!();
402        self.update_resources();
403
404        if let Some(ref bind_group) = self.bind_group {
405            pass.set_bind_group(bind_group_index, bind_group, &[]);
406        }
407    }
408
409    /// Get the bind group layout (creates it if needed).
410    ///
411    /// This is useful when creating render pipelines.
412    pub fn bind_group_layout(&mut self) -> &wgpu::BindGroupLayout {
413        if self.dirty || self.bind_group_layout.is_none() {
414            self.update_resources();
415        }
416        self.bind_group_layout
417            .as_ref()
418            .expect("Bind group layout should be created")
419    }
420}
421
422/// Builder for creating materials with a fluent API.
423pub struct MaterialBuilder {
424    shader: Option<Arc<wgpu::ShaderModule>>,
425    parameters: HashMap<String, MaterialParameter>,
426    textures: HashMap<String, MaterialTexture>,
427    pipeline_state: PipelineState,
428    context: Arc<GraphicsContext>,
429}
430
431impl MaterialBuilder {
432    /// Create a new material builder.
433    pub fn new(context: Arc<GraphicsContext>) -> Self {
434        Self {
435            shader: None,
436            parameters: HashMap::default(),
437            textures: HashMap::default(),
438            pipeline_state: PipelineState::default(),
439            context,
440        }
441    }
442
443    /// Set the shader from a module.
444    pub fn shader(mut self, shader: Arc<wgpu::ShaderModule>) -> Self {
445        self.shader = Some(shader);
446        self
447    }
448
449    /// Set the shader from source code.
450    pub fn shader_source(mut self, source: &str, label: Option<&str>) -> Self {
451        let shader = self
452            .context
453            .device()
454            .create_shader_module(wgpu::ShaderModuleDescriptor {
455                label,
456                source: wgpu::ShaderSource::Wgsl(source.into()),
457            });
458        self.shader = Some(Arc::new(shader));
459        self
460    }
461
462    /// Set a parameter.
463    pub fn parameter(mut self, name: impl Into<String>, value: MaterialParameter) -> Self {
464        self.parameters.insert(name.into(), value);
465        self
466    }
467
468    /// Set a texture.
469    pub fn texture(mut self, name: impl Into<String>, texture: MaterialTexture) -> Self {
470        self.textures.insert(name.into(), texture);
471        self
472    }
473
474    /// Set the pipeline state.
475    pub fn pipeline_state(mut self, state: PipelineState) -> Self {
476        self.pipeline_state = state;
477        self
478    }
479
480    /// Build the material.
481    pub fn build(self) -> Material {
482        let shader = self.shader.expect("Shader is required");
483        let mut material = Material::new(shader, self.context);
484        material.parameters = self.parameters;
485        material.textures = self.textures;
486        material.pipeline_state = self.pipeline_state;
487        material.dirty = true;
488        material
489    }
490}