Skip to main content

astrelis_render/
mesh.rs

1//! Mesh abstraction for high-level geometry management.
2//!
3//! Provides a declarative API for creating and rendering meshes with vertices, indices,
4//! and common primitive shapes.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use astrelis_render::*;
10//! use glam::Vec3;
11//!
12//! // Create a mesh from vertices
13//! let mesh = MeshBuilder::new()
14//!     .with_positions(vec![
15//!         Vec3::new(-0.5, -0.5, 0.0),
16//!         Vec3::new(0.5, -0.5, 0.0),
17//!         Vec3::new(0.0, 0.5, 0.0),
18//!     ])
19//!     .with_indices(vec![0, 1, 2])
//!     .build(ctx.clone()); // `ctx` is an `Arc<GraphicsContext>`
21//!
22//! // Draw the mesh
23//! mesh.draw(&mut pass);
24//!
25//! // Or create a primitive
//! let cube = Mesh::cube(ctx.clone(), 1.0);
27//! cube.draw_instanced(&mut pass, 10);
28//! ```
29
30use crate::GraphicsContext;
31use glam::{Vec2, Vec3};
32use std::sync::Arc;
33
/// Vertex format specification.
///
/// Each variant names the `f32` attributes stored per vertex, in the order they are
/// interleaved in the vertex buffer. `Hash` is derived so a format can be used as a key,
/// e.g. when caching one render pipeline per vertex layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VertexFormat {
    /// Position only (Vec3).
    Position,
    /// Position + normal (Vec3 + Vec3).
    PositionNormal,
    /// Position + UV (Vec3 + Vec2).
    PositionUv,
    /// Position + normal + UV (Vec3 + Vec3 + Vec2).
    PositionNormalUv,
    /// Position + normal + UV + color (Vec3 + Vec3 + Vec2 + 4 x f32).
    PositionNormalUvColor,
}
48
49impl VertexFormat {
50    /// Get the size of a single vertex in bytes.
51    pub fn vertex_size(&self) -> u64 {
52        match self {
53            VertexFormat::Position => 12,              // 3 floats
54            VertexFormat::PositionNormal => 24,        // 6 floats
55            VertexFormat::PositionUv => 20,            // 5 floats
56            VertexFormat::PositionNormalUv => 32,      // 8 floats
57            VertexFormat::PositionNormalUvColor => 48, // 12 floats
58        }
59    }
60
61    /// Get the WGPU vertex buffer layout for this format.
62    pub fn buffer_layout(&self) -> wgpu::VertexBufferLayout<'static> {
63        match self {
64            VertexFormat::Position => wgpu::VertexBufferLayout {
65                array_stride: 12,
66                step_mode: wgpu::VertexStepMode::Vertex,
67                attributes: &[wgpu::VertexAttribute {
68                    format: wgpu::VertexFormat::Float32x3,
69                    offset: 0,
70                    shader_location: 0,
71                }],
72            },
73            VertexFormat::PositionNormal => wgpu::VertexBufferLayout {
74                array_stride: 24,
75                step_mode: wgpu::VertexStepMode::Vertex,
76                attributes: &[
77                    wgpu::VertexAttribute {
78                        format: wgpu::VertexFormat::Float32x3,
79                        offset: 0,
80                        shader_location: 0,
81                    },
82                    wgpu::VertexAttribute {
83                        format: wgpu::VertexFormat::Float32x3,
84                        offset: 12,
85                        shader_location: 1,
86                    },
87                ],
88            },
89            VertexFormat::PositionUv => wgpu::VertexBufferLayout {
90                array_stride: 20,
91                step_mode: wgpu::VertexStepMode::Vertex,
92                attributes: &[
93                    wgpu::VertexAttribute {
94                        format: wgpu::VertexFormat::Float32x3,
95                        offset: 0,
96                        shader_location: 0,
97                    },
98                    wgpu::VertexAttribute {
99                        format: wgpu::VertexFormat::Float32x2,
100                        offset: 12,
101                        shader_location: 1,
102                    },
103                ],
104            },
105            VertexFormat::PositionNormalUv => wgpu::VertexBufferLayout {
106                array_stride: 32,
107                step_mode: wgpu::VertexStepMode::Vertex,
108                attributes: &[
109                    wgpu::VertexAttribute {
110                        format: wgpu::VertexFormat::Float32x3,
111                        offset: 0,
112                        shader_location: 0,
113                    },
114                    wgpu::VertexAttribute {
115                        format: wgpu::VertexFormat::Float32x3,
116                        offset: 12,
117                        shader_location: 1,
118                    },
119                    wgpu::VertexAttribute {
120                        format: wgpu::VertexFormat::Float32x2,
121                        offset: 24,
122                        shader_location: 2,
123                    },
124                ],
125            },
126            VertexFormat::PositionNormalUvColor => wgpu::VertexBufferLayout {
127                array_stride: 48,
128                step_mode: wgpu::VertexStepMode::Vertex,
129                attributes: &[
130                    wgpu::VertexAttribute {
131                        format: wgpu::VertexFormat::Float32x3,
132                        offset: 0,
133                        shader_location: 0,
134                    },
135                    wgpu::VertexAttribute {
136                        format: wgpu::VertexFormat::Float32x3,
137                        offset: 12,
138                        shader_location: 1,
139                    },
140                    wgpu::VertexAttribute {
141                        format: wgpu::VertexFormat::Float32x2,
142                        offset: 24,
143                        shader_location: 2,
144                    },
145                    wgpu::VertexAttribute {
146                        format: wgpu::VertexFormat::Float32x4,
147                        offset: 32,
148                        shader_location: 3,
149                    },
150                ],
151            },
152        }
153    }
154}
155
156/// A mesh containing vertex and optional index data.
157pub struct Mesh {
158    /// Vertex buffer
159    vertex_buffer: wgpu::Buffer,
160    /// Optional index buffer
161    index_buffer: Option<wgpu::Buffer>,
162    /// Vertex format
163    vertex_format: VertexFormat,
164    /// Primitive topology
165    topology: wgpu::PrimitiveTopology,
166    /// Index format (if indexed)
167    index_format: Option<wgpu::IndexFormat>,
168    /// Number of vertices
169    vertex_count: u32,
170    /// Number of indices (if indexed)
171    index_count: Option<u32>,
172    /// Graphics context reference
173    _context: Arc<GraphicsContext>,
174}
175
176impl Mesh {
177    /// Draw the mesh.
178    pub fn draw<'a>(&'a self, pass: &mut wgpu::RenderPass<'a>) {
179        pass.set_vertex_buffer(0, self.vertex_buffer.slice(..));
180
181        if let Some(ref index_buffer) = self.index_buffer {
182            let index_format = self.index_format.expect("Index format must be set");
183            pass.set_index_buffer(index_buffer.slice(..), index_format);
184            pass.draw_indexed(0..self.index_count.unwrap(), 0, 0..1);
185        } else {
186            pass.draw(0..self.vertex_count, 0..1);
187        }
188    }
189
190    /// Draw the mesh instanced.
191    pub fn draw_instanced<'a>(&'a self, pass: &mut wgpu::RenderPass<'a>, instances: u32) {
192        pass.set_vertex_buffer(0, self.vertex_buffer.slice(..));
193
194        if let Some(ref index_buffer) = self.index_buffer {
195            let index_format = self.index_format.expect("Index format must be set");
196            pass.set_index_buffer(index_buffer.slice(..), index_format);
197            pass.draw_indexed(0..self.index_count.unwrap(), 0, 0..instances);
198        } else {
199            pass.draw(0..self.vertex_count, 0..instances);
200        }
201    }
202
203    /// Get the vertex format.
204    pub fn vertex_format(&self) -> VertexFormat {
205        self.vertex_format
206    }
207
208    /// Get the primitive topology.
209    pub fn topology(&self) -> wgpu::PrimitiveTopology {
210        self.topology
211    }
212
213    /// Get the vertex count.
214    pub fn vertex_count(&self) -> u32 {
215        self.vertex_count
216    }
217
218    /// Get the index count (if indexed).
219    pub fn index_count(&self) -> Option<u32> {
220        self.index_count
221    }
222
223    // ===== Primitive Generators =====
224
225    /// Create a unit cube mesh (1x1x1) centered at origin.
226    pub fn cube(ctx: Arc<GraphicsContext>, size: f32) -> Self {
227        let half = size / 2.0;
228
229        let positions = vec![
230            // Front face
231            Vec3::new(-half, -half, half),
232            Vec3::new(half, -half, half),
233            Vec3::new(half, half, half),
234            Vec3::new(-half, half, half),
235            // Back face
236            Vec3::new(-half, -half, -half),
237            Vec3::new(-half, half, -half),
238            Vec3::new(half, half, -half),
239            Vec3::new(half, -half, -half),
240            // Top face
241            Vec3::new(-half, half, -half),
242            Vec3::new(-half, half, half),
243            Vec3::new(half, half, half),
244            Vec3::new(half, half, -half),
245            // Bottom face
246            Vec3::new(-half, -half, -half),
247            Vec3::new(half, -half, -half),
248            Vec3::new(half, -half, half),
249            Vec3::new(-half, -half, half),
250            // Right face
251            Vec3::new(half, -half, -half),
252            Vec3::new(half, half, -half),
253            Vec3::new(half, half, half),
254            Vec3::new(half, -half, half),
255            // Left face
256            Vec3::new(-half, -half, -half),
257            Vec3::new(-half, -half, half),
258            Vec3::new(-half, half, half),
259            Vec3::new(-half, half, -half),
260        ];
261
262        let normals = vec![
263            // Front
264            Vec3::new(0.0, 0.0, 1.0),
265            Vec3::new(0.0, 0.0, 1.0),
266            Vec3::new(0.0, 0.0, 1.0),
267            Vec3::new(0.0, 0.0, 1.0),
268            // Back
269            Vec3::new(0.0, 0.0, -1.0),
270            Vec3::new(0.0, 0.0, -1.0),
271            Vec3::new(0.0, 0.0, -1.0),
272            Vec3::new(0.0, 0.0, -1.0),
273            // Top
274            Vec3::new(0.0, 1.0, 0.0),
275            Vec3::new(0.0, 1.0, 0.0),
276            Vec3::new(0.0, 1.0, 0.0),
277            Vec3::new(0.0, 1.0, 0.0),
278            // Bottom
279            Vec3::new(0.0, -1.0, 0.0),
280            Vec3::new(0.0, -1.0, 0.0),
281            Vec3::new(0.0, -1.0, 0.0),
282            Vec3::new(0.0, -1.0, 0.0),
283            // Right
284            Vec3::new(1.0, 0.0, 0.0),
285            Vec3::new(1.0, 0.0, 0.0),
286            Vec3::new(1.0, 0.0, 0.0),
287            Vec3::new(1.0, 0.0, 0.0),
288            // Left
289            Vec3::new(-1.0, 0.0, 0.0),
290            Vec3::new(-1.0, 0.0, 0.0),
291            Vec3::new(-1.0, 0.0, 0.0),
292            Vec3::new(-1.0, 0.0, 0.0),
293        ];
294
295        let uvs = vec![
296            // Front
297            Vec2::new(0.0, 1.0),
298            Vec2::new(1.0, 1.0),
299            Vec2::new(1.0, 0.0),
300            Vec2::new(0.0, 0.0),
301            // Back
302            Vec2::new(1.0, 1.0),
303            Vec2::new(1.0, 0.0),
304            Vec2::new(0.0, 0.0),
305            Vec2::new(0.0, 1.0),
306            // Top
307            Vec2::new(0.0, 0.0),
308            Vec2::new(0.0, 1.0),
309            Vec2::new(1.0, 1.0),
310            Vec2::new(1.0, 0.0),
311            // Bottom
312            Vec2::new(0.0, 0.0),
313            Vec2::new(1.0, 0.0),
314            Vec2::new(1.0, 1.0),
315            Vec2::new(0.0, 1.0),
316            // Right
317            Vec2::new(1.0, 1.0),
318            Vec2::new(1.0, 0.0),
319            Vec2::new(0.0, 0.0),
320            Vec2::new(0.0, 1.0),
321            // Left
322            Vec2::new(0.0, 1.0),
323            Vec2::new(1.0, 1.0),
324            Vec2::new(1.0, 0.0),
325            Vec2::new(0.0, 0.0),
326        ];
327
328        #[rustfmt::skip]
329        let indices: Vec<u32> = vec![
330            0, 1, 2, 2, 3, 0,       // Front
331            4, 5, 6, 6, 7, 4,       // Back
332            8, 9, 10, 10, 11, 8,    // Top
333            12, 13, 14, 14, 15, 12, // Bottom
334            16, 17, 18, 18, 19, 16, // Right
335            20, 21, 22, 22, 23, 20, // Left
336        ];
337
338        MeshBuilder::new()
339            .with_positions(positions)
340            .with_normals(normals)
341            .with_uvs(uvs)
342            .with_indices(indices)
343            .build(ctx)
344    }
345
346    /// Create a plane mesh (XZ plane) centered at origin.
347    pub fn plane(ctx: Arc<GraphicsContext>, width: f32, depth: f32) -> Self {
348        let hw = width / 2.0;
349        let hd = depth / 2.0;
350
351        let positions = vec![
352            Vec3::new(-hw, 0.0, -hd),
353            Vec3::new(hw, 0.0, -hd),
354            Vec3::new(hw, 0.0, hd),
355            Vec3::new(-hw, 0.0, hd),
356        ];
357
358        let normals = vec![
359            Vec3::new(0.0, 1.0, 0.0),
360            Vec3::new(0.0, 1.0, 0.0),
361            Vec3::new(0.0, 1.0, 0.0),
362            Vec3::new(0.0, 1.0, 0.0),
363        ];
364
365        let uvs = vec![
366            Vec2::new(0.0, 1.0),
367            Vec2::new(1.0, 1.0),
368            Vec2::new(1.0, 0.0),
369            Vec2::new(0.0, 0.0),
370        ];
371
372        let indices = vec![0, 1, 2, 2, 3, 0];
373
374        MeshBuilder::new()
375            .with_positions(positions)
376            .with_normals(normals)
377            .with_uvs(uvs)
378            .with_indices(indices)
379            .build(ctx)
380    }
381
382    /// Create a sphere mesh using UV sphere generation.
383    pub fn sphere(ctx: Arc<GraphicsContext>, radius: f32, segments: u32, rings: u32) -> Self {
384        let mut positions = Vec::new();
385        let mut normals = Vec::new();
386        let mut uvs = Vec::new();
387        let mut indices = Vec::new();
388
389        // Generate vertices
390        for ring in 0..=rings {
391            let theta = ring as f32 * std::f32::consts::PI / rings as f32;
392            let sin_theta = theta.sin();
393            let cos_theta = theta.cos();
394
395            for segment in 0..=segments {
396                let phi = segment as f32 * 2.0 * std::f32::consts::PI / segments as f32;
397                let sin_phi = phi.sin();
398                let cos_phi = phi.cos();
399
400                let x = sin_theta * cos_phi;
401                let y = cos_theta;
402                let z = sin_theta * sin_phi;
403
404                positions.push(Vec3::new(x * radius, y * radius, z * radius));
405                normals.push(Vec3::new(x, y, z));
406                uvs.push(Vec2::new(
407                    segment as f32 / segments as f32,
408                    ring as f32 / rings as f32,
409                ));
410            }
411        }
412
413        // Generate indices
414        for ring in 0..rings {
415            for segment in 0..segments {
416                let first = ring * (segments + 1) + segment;
417                let second = first + segments + 1;
418
419                indices.push(first);
420                indices.push(second);
421                indices.push(first + 1);
422
423                indices.push(second);
424                indices.push(second + 1);
425                indices.push(first + 1);
426            }
427        }
428
429        MeshBuilder::new()
430            .with_positions(positions)
431            .with_normals(normals)
432            .with_uvs(uvs)
433            .with_indices(indices)
434            .build(ctx)
435    }
436}
437
438/// Builder for creating meshes.
439pub struct MeshBuilder {
440    positions: Vec<Vec3>,
441    normals: Option<Vec<Vec3>>,
442    uvs: Option<Vec<Vec2>>,
443    colors: Option<Vec<[f32; 4]>>,
444    indices: Option<Vec<u32>>,
445    topology: wgpu::PrimitiveTopology,
446}
447
448impl MeshBuilder {
449    /// Create a new mesh builder.
450    pub fn new() -> Self {
451        Self {
452            positions: Vec::new(),
453            normals: None,
454            uvs: None,
455            colors: None,
456            indices: None,
457            topology: wgpu::PrimitiveTopology::TriangleList,
458        }
459    }
460
461    /// Set positions.
462    pub fn with_positions(mut self, positions: Vec<Vec3>) -> Self {
463        self.positions = positions;
464        self
465    }
466
467    /// Set normals.
468    pub fn with_normals(mut self, normals: Vec<Vec3>) -> Self {
469        self.normals = Some(normals);
470        self
471    }
472
473    /// Set UVs.
474    pub fn with_uvs(mut self, uvs: Vec<Vec2>) -> Self {
475        self.uvs = Some(uvs);
476        self
477    }
478
479    /// Set vertex colors.
480    pub fn with_colors(mut self, colors: Vec<[f32; 4]>) -> Self {
481        self.colors = Some(colors);
482        self
483    }
484
485    /// Set indices.
486    pub fn with_indices(mut self, indices: Vec<u32>) -> Self {
487        self.indices = Some(indices);
488        self
489    }
490
491    /// Set primitive topology (default: TriangleList).
492    pub fn with_topology(mut self, topology: wgpu::PrimitiveTopology) -> Self {
493        self.topology = topology;
494        self
495    }
496
497    /// Generate flat normals (per-triangle normals).
498    pub fn generate_flat_normals(mut self) -> Self {
499        if self.indices.is_none() {
500            panic!("Cannot generate flat normals without indices");
501        }
502
503        let indices = self.indices.as_ref().unwrap();
504        let mut normals = vec![Vec3::ZERO; self.positions.len()];
505
506        for triangle in indices.chunks(3) {
507            let i0 = triangle[0] as usize;
508            let i1 = triangle[1] as usize;
509            let i2 = triangle[2] as usize;
510
511            let v0 = self.positions[i0];
512            let v1 = self.positions[i1];
513            let v2 = self.positions[i2];
514
515            let edge1 = v1 - v0;
516            let edge2 = v2 - v0;
517            let normal = edge1.cross(edge2).normalize();
518
519            normals[i0] = normal;
520            normals[i1] = normal;
521            normals[i2] = normal;
522        }
523
524        self.normals = Some(normals);
525        self
526    }
527
528    /// Generate smooth normals (averaged per-vertex normals).
529    pub fn generate_smooth_normals(mut self) -> Self {
530        if self.indices.is_none() {
531            panic!("Cannot generate smooth normals without indices");
532        }
533
534        let indices = self.indices.as_ref().unwrap();
535        let mut normals = vec![Vec3::ZERO; self.positions.len()];
536        let mut counts = vec![0u32; self.positions.len()];
537
538        for triangle in indices.chunks(3) {
539            let i0 = triangle[0] as usize;
540            let i1 = triangle[1] as usize;
541            let i2 = triangle[2] as usize;
542
543            let v0 = self.positions[i0];
544            let v1 = self.positions[i1];
545            let v2 = self.positions[i2];
546
547            let edge1 = v1 - v0;
548            let edge2 = v2 - v0;
549            let normal = edge1.cross(edge2);
550
551            normals[i0] += normal;
552            normals[i1] += normal;
553            normals[i2] += normal;
554
555            counts[i0] += 1;
556            counts[i1] += 1;
557            counts[i2] += 1;
558        }
559
560        // Average and normalize
561        for (i, normal) in normals.iter_mut().enumerate() {
562            if counts[i] > 0 {
563                *normal = (*normal / counts[i] as f32).normalize();
564            }
565        }
566
567        self.normals = Some(normals);
568        self
569    }
570
571    /// Build the mesh.
572    pub fn build(self, ctx: Arc<GraphicsContext>) -> Mesh {
573        // Determine vertex format
574        let vertex_format = match (&self.normals, &self.uvs, &self.colors) {
575            (None, None, None) => VertexFormat::Position,
576            (Some(_), None, None) => VertexFormat::PositionNormal,
577            (None, Some(_), None) => VertexFormat::PositionUv,
578            (Some(_), Some(_), None) => VertexFormat::PositionNormalUv,
579            (Some(_), Some(_), Some(_)) => VertexFormat::PositionNormalUvColor,
580            _ => panic!("Invalid vertex format combination"),
581        };
582
583        // Build vertex data
584        let mut vertex_data = Vec::new();
585        for i in 0..self.positions.len() {
586            // Position
587            vertex_data.extend_from_slice(bytemuck::bytes_of(&self.positions[i]));
588
589            // Normal
590            if let Some(ref normals) = self.normals {
591                vertex_data.extend_from_slice(bytemuck::bytes_of(&normals[i]));
592            }
593
594            // UV
595            if let Some(ref uvs) = self.uvs {
596                vertex_data.extend_from_slice(bytemuck::bytes_of(&uvs[i]));
597            }
598
599            // Color
600            if let Some(ref colors) = self.colors {
601                vertex_data.extend_from_slice(bytemuck::bytes_of(&colors[i]));
602            }
603        }
604
605        // Create vertex buffer
606        let vertex_buffer = ctx.device().create_buffer(&wgpu::BufferDescriptor {
607            label: Some("Mesh Vertex Buffer"),
608            size: vertex_data.len() as u64,
609            usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST,
610            mapped_at_creation: false,
611        });
612        ctx.queue().write_buffer(&vertex_buffer, 0, &vertex_data);
613
614        // Create index buffer if present
615        let (index_buffer, index_format, index_count) = if let Some(ref indices) = self.indices {
616            let buffer = ctx.device().create_buffer(&wgpu::BufferDescriptor {
617                label: Some("Mesh Index Buffer"),
618                size: (indices.len() * std::mem::size_of::<u32>()) as u64,
619                usage: wgpu::BufferUsages::INDEX | wgpu::BufferUsages::COPY_DST,
620                mapped_at_creation: false,
621            });
622            ctx.queue()
623                .write_buffer(&buffer, 0, bytemuck::cast_slice(indices));
624            (
625                Some(buffer),
626                Some(wgpu::IndexFormat::Uint32),
627                Some(indices.len() as u32),
628            )
629        } else {
630            (None, None, None)
631        };
632
633        Mesh {
634            vertex_buffer,
635            index_buffer,
636            vertex_format,
637            topology: self.topology,
638            index_format,
639            vertex_count: self.positions.len() as u32,
640            index_count,
641            _context: ctx,
642        }
643    }
644}
645
646impl Default for MeshBuilder {
647    fn default() -> Self {
648        Self::new()
649    }
650}