Skip to main content

astrelis_render/
material.rs

//! Material system for high-level shader parameter management.
//!
//! Provides a declarative API for managing shader parameters, textures, and pipeline state.
//!
//! # Example
//!
//! ```ignore
//! use astrelis_render::*;
//! use glam::Mat4;
//!
//! // `shader` is an `Arc<wgpu::ShaderModule>`, `context` an `Arc<GraphicsContext>`.
//! let mut material = Material::new(shader, context.clone());
//!
//! // Set parameters
//! material.set_parameter("color", MaterialParameter::Color(Color::RED));
//! material.set_parameter("time", MaterialParameter::Float(1.5));
//! material.set_parameter("view_proj", MaterialParameter::Matrix4(view_proj_matrix));
//!
//! // Set textures (a `None` sampler falls back to a default linear sampler)
//! material.set_texture(
//!     "albedo_texture",
//!     MaterialTexture { texture, view, sampler: None },
//! );
//!
//! // Apply material to render pass at bind group index 0
//! material.bind(&mut pass, 0);
//! ```
24
25use astrelis_core::profiling::profile_function;
26
27use crate::{Color, GraphicsContext};
28use ahash::HashMap;
29use glam::{Mat4, Vec2, Vec3, Vec4};
30use std::sync::Arc;
31
32/// A material parameter value that can be bound to a shader.
33#[derive(Debug, Clone)]
34pub enum MaterialParameter {
35    /// Single float value
36    Float(f32),
37    /// 2D vector
38    Vec2(Vec2),
39    /// 3D vector
40    Vec3(Vec3),
41    /// 4D vector
42    Vec4(Vec4),
43    /// RGBA color
44    Color(Color),
45    /// 4x4 matrix
46    Matrix4(Mat4),
47    /// Array of floats
48    FloatArray(Vec<f32>),
49    /// Array of Vec2
50    Vec2Array(Vec<Vec2>),
51    /// Array of Vec3
52    Vec3Array(Vec<Vec3>),
53    /// Array of Vec4
54    Vec4Array(Vec<Vec4>),
55}
56
57impl MaterialParameter {
58    /// Convert parameter to bytes for GPU upload.
59    pub fn as_bytes(&self) -> Vec<u8> {
60        match self {
61            MaterialParameter::Float(v) => bytemuck::bytes_of(v).to_vec(),
62            MaterialParameter::Vec2(v) => bytemuck::bytes_of(v).to_vec(),
63            MaterialParameter::Vec3(v) => {
64                // Pad Vec3 to 16 bytes for alignment
65                let mut bytes = Vec::with_capacity(16);
66                bytes.extend_from_slice(bytemuck::bytes_of(v));
67                bytes.extend_from_slice(&[0u8; 4]); // padding
68                bytes
69            }
70            MaterialParameter::Vec4(v) => bytemuck::bytes_of(v).to_vec(),
71            MaterialParameter::Color(c) => bytemuck::bytes_of(c).to_vec(),
72            MaterialParameter::Matrix4(m) => bytemuck::bytes_of(m).to_vec(),
73            MaterialParameter::FloatArray(arr) => bytemuck::cast_slice(arr).to_vec(),
74            MaterialParameter::Vec2Array(arr) => bytemuck::cast_slice(arr).to_vec(),
75            MaterialParameter::Vec3Array(arr) => {
76                // Each Vec3 needs padding to 16 bytes
77                let mut bytes = Vec::with_capacity(arr.len() * 16);
78                for v in arr {
79                    bytes.extend_from_slice(bytemuck::bytes_of(v));
80                    bytes.extend_from_slice(&[0u8; 4]); // padding
81                }
82                bytes
83            }
84            MaterialParameter::Vec4Array(arr) => bytemuck::cast_slice(arr).to_vec(),
85        }
86    }
87
88    /// Get the size of the parameter in bytes (including padding).
89    pub fn size(&self) -> u64 {
90        match self {
91            MaterialParameter::Float(_) => 4,
92            MaterialParameter::Vec2(_) => 8,
93            MaterialParameter::Vec3(_) => 16, // Padded
94            MaterialParameter::Vec4(_) => 16,
95            MaterialParameter::Color(_) => 16,
96            MaterialParameter::Matrix4(_) => 64,
97            MaterialParameter::FloatArray(arr) => (arr.len() * 4) as u64,
98            MaterialParameter::Vec2Array(arr) => (arr.len() * 8) as u64,
99            MaterialParameter::Vec3Array(arr) => (arr.len() * 16) as u64, // Padded
100            MaterialParameter::Vec4Array(arr) => (arr.len() * 16) as u64,
101        }
102    }
103}
104
105/// Texture binding information for a material.
106#[derive(Debug, Clone)]
107pub struct MaterialTexture {
108    /// The texture to bind
109    pub texture: wgpu::Texture,
110    /// The texture view
111    pub view: wgpu::TextureView,
112    /// Optional sampler (if None, a default linear sampler will be used)
113    pub sampler: Option<wgpu::Sampler>,
114}
115
116/// Pipeline state configuration for a material.
117#[derive(Debug, Clone)]
118pub struct PipelineState {
119    /// Primitive topology (default: TriangleList)
120    pub topology: wgpu::PrimitiveTopology,
121    /// Cull mode (default: Some(Back))
122    pub cull_mode: Option<wgpu::Face>,
123    /// Front face winding (default: Ccw)
124    pub front_face: wgpu::FrontFace,
125    /// Polygon mode (default: Fill)
126    pub polygon_mode: wgpu::PolygonMode,
127    /// Depth test enabled (default: false)
128    pub depth_test: bool,
129    /// Depth write enabled (default: false)
130    pub depth_write: bool,
131    /// Blend mode (default: None - opaque)
132    pub blend: Option<wgpu::BlendState>,
133}
134
135impl Default for PipelineState {
136    fn default() -> Self {
137        Self {
138            topology: wgpu::PrimitiveTopology::TriangleList,
139            cull_mode: Some(wgpu::Face::Back),
140            front_face: wgpu::FrontFace::Ccw,
141            polygon_mode: wgpu::PolygonMode::Fill,
142            depth_test: false,
143            depth_write: false,
144            blend: None,
145        }
146    }
147}
148
149/// A material manages shader parameters, textures, and pipeline state.
150pub struct Material {
151    /// The shader module
152    shader: Arc<wgpu::ShaderModule>,
153    /// Named parameters
154    parameters: HashMap<String, MaterialParameter>,
155    /// Named textures
156    textures: HashMap<String, MaterialTexture>,
157    /// Pipeline state
158    pipeline_state: PipelineState,
159    /// Graphics context reference
160    context: Arc<GraphicsContext>,
161    /// Cached uniform buffer
162    uniform_buffer: Option<wgpu::Buffer>,
163    /// Cached bind group layout
164    bind_group_layout: Option<wgpu::BindGroupLayout>,
165    /// Cached bind group
166    bind_group: Option<wgpu::BindGroup>,
167    /// Dirty flag - set to true when parameters/textures change
168    dirty: bool,
169}
170
171impl Material {
172    /// Create a new material with a shader.
173    pub fn new(shader: Arc<wgpu::ShaderModule>, context: Arc<GraphicsContext>) -> Self {
174        Self {
175            shader,
176            parameters: HashMap::default(),
177            textures: HashMap::default(),
178            pipeline_state: PipelineState::default(),
179            context,
180            uniform_buffer: None,
181            bind_group_layout: None,
182            bind_group: None,
183            dirty: true,
184        }
185    }
186
187    /// Create a material from a shader source string.
188    pub fn from_source(source: &str, label: Option<&str>, context: Arc<GraphicsContext>) -> Self {
189        profile_function!();
190        let shader = context
191            .device()
192            .create_shader_module(wgpu::ShaderModuleDescriptor {
193                label,
194                source: wgpu::ShaderSource::Wgsl(source.into()),
195            });
196        Self::new(Arc::new(shader), context)
197    }
198
199    /// Set a parameter by name.
200    pub fn set_parameter(&mut self, name: impl Into<String>, value: MaterialParameter) {
201        self.parameters.insert(name.into(), value);
202        self.dirty = true;
203    }
204
205    /// Get a parameter by name.
206    pub fn get_parameter(&self, name: &str) -> Option<&MaterialParameter> {
207        self.parameters.get(name)
208    }
209
210    /// Set a texture by name.
211    pub fn set_texture(&mut self, name: impl Into<String>, texture: MaterialTexture) {
212        self.textures.insert(name.into(), texture);
213        self.dirty = true;
214    }
215
216    /// Get a texture by name.
217    pub fn get_texture(&self, name: &str) -> Option<&MaterialTexture> {
218        self.textures.get(name)
219    }
220
221    /// Set the pipeline state.
222    pub fn set_pipeline_state(&mut self, state: PipelineState) {
223        self.pipeline_state = state;
224    }
225
226    /// Get the pipeline state.
227    pub fn pipeline_state(&self) -> &PipelineState {
228        &self.pipeline_state
229    }
230
231    /// Get the shader module.
232    pub fn shader(&self) -> &wgpu::ShaderModule {
233        &self.shader
234    }
235
236    /// Update GPU resources if dirty.
237    fn update_resources(&mut self) {
238        profile_function!();
239        if !self.dirty {
240            return;
241        }
242
243        // Calculate total uniform buffer size
244        let mut uniform_size = 0u64;
245        for param in self.parameters.values() {
246            uniform_size += param.size();
247            // Add padding for alignment
248            if !uniform_size.is_multiple_of(16) {
249                uniform_size += 16 - (uniform_size % 16);
250            }
251        }
252
253        // Create or update uniform buffer
254        if uniform_size > 0 {
255            let mut uniform_data = Vec::new();
256            for param in self.parameters.values() {
257                uniform_data.extend_from_slice(&param.as_bytes());
258                // Add padding for alignment
259                let current_size = uniform_data.len() as u64;
260                if !current_size.is_multiple_of(16) {
261                    let padding = 16 - (current_size % 16);
262                    uniform_data.extend(vec![0u8; padding as usize]);
263                }
264            }
265
266            if let Some(buffer) = &self.uniform_buffer {
267                // Update existing buffer
268                self.context.queue().write_buffer(buffer, 0, &uniform_data);
269            } else {
270                // Create new buffer
271                let buffer = self
272                    .context
273                    .device()
274                    .create_buffer(&wgpu::BufferDescriptor {
275                        label: Some("Material Uniform Buffer"),
276                        size: uniform_size,
277                        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
278                        mapped_at_creation: false,
279                    });
280                self.context.queue().write_buffer(&buffer, 0, &uniform_data);
281                self.uniform_buffer = Some(buffer);
282            }
283        }
284
285        // Rebuild bind group layout and bind group
286        self.rebuild_bind_groups();
287
288        self.dirty = false;
289    }
290
291    /// Rebuild bind group layout and bind group.
292    fn rebuild_bind_groups(&mut self) {
293        let mut layout_entries = Vec::new();
294        let mut bind_entries = Vec::new();
295        let mut binding = 0u32;
296
297        // Add uniform buffer binding if present
298        if let Some(buf) = &self.uniform_buffer {
299            layout_entries.push(wgpu::BindGroupLayoutEntry {
300                binding,
301                visibility: wgpu::ShaderStages::VERTEX_FRAGMENT,
302                ty: wgpu::BindingType::Buffer {
303                    ty: wgpu::BufferBindingType::Uniform,
304                    has_dynamic_offset: false,
305                    min_binding_size: None,
306                },
307                count: None,
308            });
309
310            bind_entries.push(wgpu::BindGroupEntry {
311                binding,
312                resource: buf.as_entire_binding(),
313            });
314
315            binding += 1;
316        }
317
318        // Add texture bindings
319        for texture in self.textures.values() {
320            // Texture binding
321            layout_entries.push(wgpu::BindGroupLayoutEntry {
322                binding,
323                visibility: wgpu::ShaderStages::FRAGMENT,
324                ty: wgpu::BindingType::Texture {
325                    sample_type: wgpu::TextureSampleType::Float { filterable: true },
326                    view_dimension: wgpu::TextureViewDimension::D2,
327                    multisampled: false,
328                },
329                count: None,
330            });
331
332            bind_entries.push(wgpu::BindGroupEntry {
333                binding,
334                resource: wgpu::BindingResource::TextureView(&texture.view),
335            });
336
337            binding += 1;
338
339            // Sampler binding
340            layout_entries.push(wgpu::BindGroupLayoutEntry {
341                binding,
342                visibility: wgpu::ShaderStages::FRAGMENT,
343                ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
344                count: None,
345            });
346
347            // Use provided sampler or create default
348            let sampler = if let Some(ref s) = texture.sampler {
349                s
350            } else {
351                // Create a default linear sampler (this should be cached in practice)
352                // For now, we'll use a temporary one
353                // TODO: Add sampler cache to Material or GraphicsContext
354                unimplemented!("Default sampler not yet implemented - please provide sampler")
355            };
356
357            bind_entries.push(wgpu::BindGroupEntry {
358                binding,
359                resource: wgpu::BindingResource::Sampler(sampler),
360            });
361
362            binding += 1;
363        }
364
365        // Create bind group layout
366        let layout =
367            self.context
368                .device()
369                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
370                    label: Some("Material Bind Group Layout"),
371                    entries: &layout_entries,
372                });
373
374        // Create bind group
375        let bind_group = self
376            .context
377            .device()
378            .create_bind_group(&wgpu::BindGroupDescriptor {
379                label: Some("Material Bind Group"),
380                layout: &layout,
381                entries: &bind_entries,
382            });
383
384        self.bind_group_layout = Some(layout);
385        self.bind_group = Some(bind_group);
386    }
387
388    /// Bind this material's resources to a render pass.
389    ///
390    /// This will update GPU resources if needed and set the bind group.
391    ///
392    /// # Arguments
393    ///
394    /// * `pass` - The render pass to bind to
395    /// * `bind_group_index` - The bind group index (default is usually 0)
396    pub fn bind<'a>(&'a mut self, pass: &mut wgpu::RenderPass<'a>, bind_group_index: u32) {
397        profile_function!();
398        self.update_resources();
399
400        if let Some(ref bind_group) = self.bind_group {
401            pass.set_bind_group(bind_group_index, bind_group, &[]);
402        }
403    }
404
405    /// Get the bind group layout (creates it if needed).
406    ///
407    /// This is useful when creating render pipelines.
408    pub fn bind_group_layout(&mut self) -> &wgpu::BindGroupLayout {
409        if self.dirty || self.bind_group_layout.is_none() {
410            self.update_resources();
411        }
412        self.bind_group_layout
413            .as_ref()
414            .expect("Bind group layout should be created")
415    }
416}
417
418/// Builder for creating materials with a fluent API.
419pub struct MaterialBuilder {
420    shader: Option<Arc<wgpu::ShaderModule>>,
421    parameters: HashMap<String, MaterialParameter>,
422    textures: HashMap<String, MaterialTexture>,
423    pipeline_state: PipelineState,
424    context: Arc<GraphicsContext>,
425}
426
427impl MaterialBuilder {
428    /// Create a new material builder.
429    pub fn new(context: Arc<GraphicsContext>) -> Self {
430        Self {
431            shader: None,
432            parameters: HashMap::default(),
433            textures: HashMap::default(),
434            pipeline_state: PipelineState::default(),
435            context,
436        }
437    }
438
439    /// Set the shader from a module.
440    pub fn shader(mut self, shader: Arc<wgpu::ShaderModule>) -> Self {
441        self.shader = Some(shader);
442        self
443    }
444
445    /// Set the shader from source code.
446    pub fn shader_source(mut self, source: &str, label: Option<&str>) -> Self {
447        let shader = self
448            .context
449            .device()
450            .create_shader_module(wgpu::ShaderModuleDescriptor {
451                label,
452                source: wgpu::ShaderSource::Wgsl(source.into()),
453            });
454        self.shader = Some(Arc::new(shader));
455        self
456    }
457
458    /// Set a parameter.
459    pub fn parameter(mut self, name: impl Into<String>, value: MaterialParameter) -> Self {
460        self.parameters.insert(name.into(), value);
461        self
462    }
463
464    /// Set a texture.
465    pub fn texture(mut self, name: impl Into<String>, texture: MaterialTexture) -> Self {
466        self.textures.insert(name.into(), texture);
467        self
468    }
469
470    /// Set the pipeline state.
471    pub fn pipeline_state(mut self, state: PipelineState) -> Self {
472        self.pipeline_state = state;
473        self
474    }
475
476    /// Build the material.
477    pub fn build(self) -> Material {
478        let shader = self.shader.expect("Shader is required");
479        let mut material = Material::new(shader, self.context);
480        material.parameters = self.parameters;
481        material.textures = self.textures;
482        material.pipeline_state = self.pipeline_state;
483        material.dirty = true;
484        material
485    }
486}