astrelis_render/
material.rs

//! Material system for high-level shader parameter management.
//!
//! Provides a declarative API for managing shader parameters, textures, and pipeline state.
//!
//! # Example
//!
//! ```ignore
//! use astrelis_render::*;
//! use glam::{Vec2, Vec3, Mat4};
//!
//! // `shader` is an `Arc<wgpu::ShaderModule>`, `context` an `Arc<GraphicsContext>`.
//! let mut material = Material::new(shader, context.clone());
//!
//! // Set parameters
//! material.set_parameter("color", MaterialParameter::Color(Color::RED));
//! material.set_parameter("time", MaterialParameter::Float(1.5));
//! material.set_parameter("view_proj", MaterialParameter::Matrix4(view_proj_matrix));
//!
//! // Set textures (`albedo` is a `MaterialTexture`)
//! material.set_texture("albedo_texture", albedo);
//!
//! // Apply material to render pass at bind group index 0
//! material.bind(&mut pass, 0);
//! ```
24
25use crate::{Color, GraphicsContext};
26use ahash::HashMap;
27use glam::{Mat4, Vec2, Vec3, Vec4};
28use std::sync::Arc;
29
30/// A material parameter value that can be bound to a shader.
31#[derive(Debug, Clone)]
32pub enum MaterialParameter {
33    /// Single float value
34    Float(f32),
35    /// 2D vector
36    Vec2(Vec2),
37    /// 3D vector
38    Vec3(Vec3),
39    /// 4D vector
40    Vec4(Vec4),
41    /// RGBA color
42    Color(Color),
43    /// 4x4 matrix
44    Matrix4(Mat4),
45    /// Array of floats
46    FloatArray(Vec<f32>),
47    /// Array of Vec2
48    Vec2Array(Vec<Vec2>),
49    /// Array of Vec3
50    Vec3Array(Vec<Vec3>),
51    /// Array of Vec4
52    Vec4Array(Vec<Vec4>),
53}
54
55impl MaterialParameter {
56    /// Convert parameter to bytes for GPU upload.
57    pub fn as_bytes(&self) -> Vec<u8> {
58        match self {
59            MaterialParameter::Float(v) => bytemuck::bytes_of(v).to_vec(),
60            MaterialParameter::Vec2(v) => bytemuck::bytes_of(v).to_vec(),
61            MaterialParameter::Vec3(v) => {
62                // Pad Vec3 to 16 bytes for alignment
63                let mut bytes = Vec::with_capacity(16);
64                bytes.extend_from_slice(bytemuck::bytes_of(v));
65                bytes.extend_from_slice(&[0u8; 4]); // padding
66                bytes
67            }
68            MaterialParameter::Vec4(v) => bytemuck::bytes_of(v).to_vec(),
69            MaterialParameter::Color(c) => bytemuck::bytes_of(c).to_vec(),
70            MaterialParameter::Matrix4(m) => bytemuck::bytes_of(m).to_vec(),
71            MaterialParameter::FloatArray(arr) => bytemuck::cast_slice(arr).to_vec(),
72            MaterialParameter::Vec2Array(arr) => bytemuck::cast_slice(arr).to_vec(),
73            MaterialParameter::Vec3Array(arr) => {
74                // Each Vec3 needs padding to 16 bytes
75                let mut bytes = Vec::with_capacity(arr.len() * 16);
76                for v in arr {
77                    bytes.extend_from_slice(bytemuck::bytes_of(v));
78                    bytes.extend_from_slice(&[0u8; 4]); // padding
79                }
80                bytes
81            }
82            MaterialParameter::Vec4Array(arr) => bytemuck::cast_slice(arr).to_vec(),
83        }
84    }
85
86    /// Get the size of the parameter in bytes (including padding).
87    pub fn size(&self) -> u64 {
88        match self {
89            MaterialParameter::Float(_) => 4,
90            MaterialParameter::Vec2(_) => 8,
91            MaterialParameter::Vec3(_) => 16, // Padded
92            MaterialParameter::Vec4(_) => 16,
93            MaterialParameter::Color(_) => 16,
94            MaterialParameter::Matrix4(_) => 64,
95            MaterialParameter::FloatArray(arr) => (arr.len() * 4) as u64,
96            MaterialParameter::Vec2Array(arr) => (arr.len() * 8) as u64,
97            MaterialParameter::Vec3Array(arr) => (arr.len() * 16) as u64, // Padded
98            MaterialParameter::Vec4Array(arr) => (arr.len() * 16) as u64,
99        }
100    }
101}
102
/// Texture binding information for a material.
///
/// A material binds each texture as two consecutive entries in its bind
/// group: `view` as a fragment-stage 2D filterable float texture, followed by
/// `sampler` as a fragment-stage filtering sampler.
#[derive(Debug, Clone)]
pub struct MaterialTexture {
    /// The texture itself; the material's bind group only uses `view` and
    /// `sampler`, this handle is carried alongside them.
    pub texture: wgpu::Texture,
    /// The texture view placed in the bind group; it has to be compatible with
    /// a 2D, filterable-float texture layout entry.
    pub view: wgpu::TextureView,
    /// Optional sampler (if None, a default linear sampler will be used)
    pub sampler: Option<wgpu::Sampler>,
}
113
114/// Pipeline state configuration for a material.
115#[derive(Debug, Clone)]
116pub struct PipelineState {
117    /// Primitive topology (default: TriangleList)
118    pub topology: wgpu::PrimitiveTopology,
119    /// Cull mode (default: Some(Back))
120    pub cull_mode: Option<wgpu::Face>,
121    /// Front face winding (default: Ccw)
122    pub front_face: wgpu::FrontFace,
123    /// Polygon mode (default: Fill)
124    pub polygon_mode: wgpu::PolygonMode,
125    /// Depth test enabled (default: false)
126    pub depth_test: bool,
127    /// Depth write enabled (default: false)
128    pub depth_write: bool,
129    /// Blend mode (default: None - opaque)
130    pub blend: Option<wgpu::BlendState>,
131}
132
133impl Default for PipelineState {
134    fn default() -> Self {
135        Self {
136            topology: wgpu::PrimitiveTopology::TriangleList,
137            cull_mode: Some(wgpu::Face::Back),
138            front_face: wgpu::FrontFace::Ccw,
139            polygon_mode: wgpu::PolygonMode::Fill,
140            depth_test: false,
141            depth_write: false,
142            blend: None,
143        }
144    }
145}
146
147/// A material manages shader parameters, textures, and pipeline state.
148pub struct Material {
149    /// The shader module
150    shader: Arc<wgpu::ShaderModule>,
151    /// Named parameters
152    parameters: HashMap<String, MaterialParameter>,
153    /// Named textures
154    textures: HashMap<String, MaterialTexture>,
155    /// Pipeline state
156    pipeline_state: PipelineState,
157    /// Graphics context reference
158    context: Arc<GraphicsContext>,
159    /// Cached uniform buffer
160    uniform_buffer: Option<wgpu::Buffer>,
161    /// Cached bind group layout
162    bind_group_layout: Option<wgpu::BindGroupLayout>,
163    /// Cached bind group
164    bind_group: Option<wgpu::BindGroup>,
165    /// Dirty flag - set to true when parameters/textures change
166    dirty: bool,
167}
168
169impl Material {
170    /// Create a new material with a shader.
171    pub fn new(shader: Arc<wgpu::ShaderModule>, context: Arc<GraphicsContext>) -> Self {
172        Self {
173            shader,
174            parameters: HashMap::default(),
175            textures: HashMap::default(),
176            pipeline_state: PipelineState::default(),
177            context,
178            uniform_buffer: None,
179            bind_group_layout: None,
180            bind_group: None,
181            dirty: true,
182        }
183    }
184
185    /// Create a material from a shader source string.
186    pub fn from_source(
187        source: &str,
188        label: Option<&str>,
189        context: Arc<GraphicsContext>,
190    ) -> Self {
191        let shader = context
192            .device
193            .create_shader_module(wgpu::ShaderModuleDescriptor {
194                label,
195                source: wgpu::ShaderSource::Wgsl(source.into()),
196            });
197        Self::new(Arc::new(shader), context)
198    }
199
200    /// Set a parameter by name.
201    pub fn set_parameter(&mut self, name: impl Into<String>, value: MaterialParameter) {
202        self.parameters.insert(name.into(), value);
203        self.dirty = true;
204    }
205
206    /// Get a parameter by name.
207    pub fn get_parameter(&self, name: &str) -> Option<&MaterialParameter> {
208        self.parameters.get(name)
209    }
210
211    /// Set a texture by name.
212    pub fn set_texture(&mut self, name: impl Into<String>, texture: MaterialTexture) {
213        self.textures.insert(name.into(), texture);
214        self.dirty = true;
215    }
216
217    /// Get a texture by name.
218    pub fn get_texture(&self, name: &str) -> Option<&MaterialTexture> {
219        self.textures.get(name)
220    }
221
222    /// Set the pipeline state.
223    pub fn set_pipeline_state(&mut self, state: PipelineState) {
224        self.pipeline_state = state;
225    }
226
227    /// Get the pipeline state.
228    pub fn pipeline_state(&self) -> &PipelineState {
229        &self.pipeline_state
230    }
231
232    /// Get the shader module.
233    pub fn shader(&self) -> &wgpu::ShaderModule {
234        &self.shader
235    }
236
237    /// Update GPU resources if dirty.
238    fn update_resources(&mut self) {
239        if !self.dirty {
240            return;
241        }
242
243        // Calculate total uniform buffer size
244        let mut uniform_size = 0u64;
245        for param in self.parameters.values() {
246            uniform_size += param.size();
247            // Add padding for alignment
248            if uniform_size % 16 != 0 {
249                uniform_size += 16 - (uniform_size % 16);
250            }
251        }
252
253        // Create or update uniform buffer
254        if uniform_size > 0 {
255            let mut uniform_data = Vec::new();
256            for param in self.parameters.values() {
257                uniform_data.extend_from_slice(&param.as_bytes());
258                // Add padding for alignment
259                let current_size = uniform_data.len() as u64;
260                if current_size % 16 != 0 {
261                    let padding = 16 - (current_size % 16);
262                    uniform_data.extend(vec![0u8; padding as usize]);
263                }
264            }
265
266            if let Some(buffer) = &self.uniform_buffer {
267                // Update existing buffer
268                self.context.queue.write_buffer(buffer, 0, &uniform_data);
269            } else {
270                // Create new buffer
271                let buffer = self
272                    .context
273                    .device
274                    .create_buffer(&wgpu::BufferDescriptor {
275                        label: Some("Material Uniform Buffer"),
276                        size: uniform_size,
277                        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
278                        mapped_at_creation: false,
279                    });
280                self.context.queue.write_buffer(&buffer, 0, &uniform_data);
281                self.uniform_buffer = Some(buffer);
282            }
283        }
284
285        // Rebuild bind group layout and bind group
286        self.rebuild_bind_groups();
287
288        self.dirty = false;
289    }
290
291    /// Rebuild bind group layout and bind group.
292    fn rebuild_bind_groups(&mut self) {
293        let mut layout_entries = Vec::new();
294        let mut bind_entries = Vec::new();
295        let mut binding = 0u32;
296
297        // Add uniform buffer binding if present
298        if self.uniform_buffer.is_some() {
299            layout_entries.push(wgpu::BindGroupLayoutEntry {
300                binding,
301                visibility: wgpu::ShaderStages::VERTEX_FRAGMENT,
302                ty: wgpu::BindingType::Buffer {
303                    ty: wgpu::BufferBindingType::Uniform,
304                    has_dynamic_offset: false,
305                    min_binding_size: None,
306                },
307                count: None,
308            });
309
310            bind_entries.push(wgpu::BindGroupEntry {
311                binding,
312                resource: self.uniform_buffer.as_ref().unwrap().as_entire_binding(),
313            });
314
315            binding += 1;
316        }
317
318        // Add texture bindings
319        for texture in self.textures.values() {
320            // Texture binding
321            layout_entries.push(wgpu::BindGroupLayoutEntry {
322                binding,
323                visibility: wgpu::ShaderStages::FRAGMENT,
324                ty: wgpu::BindingType::Texture {
325                    sample_type: wgpu::TextureSampleType::Float { filterable: true },
326                    view_dimension: wgpu::TextureViewDimension::D2,
327                    multisampled: false,
328                },
329                count: None,
330            });
331
332            bind_entries.push(wgpu::BindGroupEntry {
333                binding,
334                resource: wgpu::BindingResource::TextureView(&texture.view),
335            });
336
337            binding += 1;
338
339            // Sampler binding
340            layout_entries.push(wgpu::BindGroupLayoutEntry {
341                binding,
342                visibility: wgpu::ShaderStages::FRAGMENT,
343                ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
344                count: None,
345            });
346
347            // Use provided sampler or create default
348            let sampler = if let Some(ref s) = texture.sampler {
349                s
350            } else {
351                // Create a default linear sampler (this should be cached in practice)
352                // For now, we'll use a temporary one
353                // TODO: Add sampler cache to Material or GraphicsContext
354                unimplemented!("Default sampler not yet implemented - please provide sampler")
355            };
356
357            bind_entries.push(wgpu::BindGroupEntry {
358                binding,
359                resource: wgpu::BindingResource::Sampler(sampler),
360            });
361
362            binding += 1;
363        }
364
365        // Create bind group layout
366        let layout = self
367            .context
368            .device
369            .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
370                label: Some("Material Bind Group Layout"),
371                entries: &layout_entries,
372            });
373
374        // Create bind group
375        let bind_group = self
376            .context
377            .device
378            .create_bind_group(&wgpu::BindGroupDescriptor {
379                label: Some("Material Bind Group"),
380                layout: &layout,
381                entries: &bind_entries,
382            });
383
384        self.bind_group_layout = Some(layout);
385        self.bind_group = Some(bind_group);
386    }
387
388    /// Bind this material's resources to a render pass.
389    ///
390    /// This will update GPU resources if needed and set the bind group.
391    ///
392    /// # Arguments
393    ///
394    /// * `pass` - The render pass to bind to
395    /// * `bind_group_index` - The bind group index (default is usually 0)
396    pub fn bind<'a>(&'a mut self, pass: &mut wgpu::RenderPass<'a>, bind_group_index: u32) {
397        self.update_resources();
398
399        if let Some(ref bind_group) = self.bind_group {
400            pass.set_bind_group(bind_group_index, bind_group, &[]);
401        }
402    }
403
404    /// Get the bind group layout (creates it if needed).
405    ///
406    /// This is useful when creating render pipelines.
407    pub fn bind_group_layout(&mut self) -> &wgpu::BindGroupLayout {
408        if self.dirty || self.bind_group_layout.is_none() {
409            self.update_resources();
410        }
411        self.bind_group_layout
412            .as_ref()
413            .expect("Bind group layout should be created")
414    }
415}
416
417/// Builder for creating materials with a fluent API.
418pub struct MaterialBuilder {
419    shader: Option<Arc<wgpu::ShaderModule>>,
420    parameters: HashMap<String, MaterialParameter>,
421    textures: HashMap<String, MaterialTexture>,
422    pipeline_state: PipelineState,
423    context: Arc<GraphicsContext>,
424}
425
426impl MaterialBuilder {
427    /// Create a new material builder.
428    pub fn new(context: Arc<GraphicsContext>) -> Self {
429        Self {
430            shader: None,
431            parameters: HashMap::default(),
432            textures: HashMap::default(),
433            pipeline_state: PipelineState::default(),
434            context,
435        }
436    }
437
438    /// Set the shader from a module.
439    pub fn shader(mut self, shader: Arc<wgpu::ShaderModule>) -> Self {
440        self.shader = Some(shader);
441        self
442    }
443
444    /// Set the shader from source code.
445    pub fn shader_source(mut self, source: &str, label: Option<&str>) -> Self {
446        let shader = self
447            .context
448            .device
449            .create_shader_module(wgpu::ShaderModuleDescriptor {
450                label,
451                source: wgpu::ShaderSource::Wgsl(source.into()),
452            });
453        self.shader = Some(Arc::new(shader));
454        self
455    }
456
457    /// Set a parameter.
458    pub fn parameter(mut self, name: impl Into<String>, value: MaterialParameter) -> Self {
459        self.parameters.insert(name.into(), value);
460        self
461    }
462
463    /// Set a texture.
464    pub fn texture(mut self, name: impl Into<String>, texture: MaterialTexture) -> Self {
465        self.textures.insert(name.into(), texture);
466        self
467    }
468
469    /// Set the pipeline state.
470    pub fn pipeline_state(mut self, state: PipelineState) -> Self {
471        self.pipeline_state = state;
472        self
473    }
474
475    /// Build the material.
476    pub fn build(self) -> Material {
477        let shader = self.shader.expect("Shader is required");
478        let mut material = Material::new(shader, self.context);
479        material.parameters = self.parameters;
480        material.textures = self.textures;
481        material.pipeline_state = self.pipeline_state;
482        material.dirty = true;
483        material
484    }
485}