Skip to main content

oxide_renderer/gltf/
loader.rs

1//! glTF model loader
2
3use std::path::Path;
4
5use glam::{Quat, Vec3};
6use gltf::buffer::Data;
7use gltf::mesh::Mode;
8use wgpu::{Device, Queue};
9
10use crate::mesh::Mesh3D;
11use crate::mesh::Vertex3D;
12
13#[derive(thiserror::Error, Debug)]
14pub enum GltfError {
15    #[error("Failed to load glTF file '{path}': {source}")]
16    Load { path: String, source: gltf::Error },
17    #[error("Failed to read glTF buffers: {0}")]
18    Buffer(String),
19    #[error("Primitive mode {0:?} is not supported. Only triangles are supported.")]
20    UnsupportedMode(Mode),
21    #[error("Mesh has no positions")]
22    MissingPositions,
23}
24
25/// Result of loading a glTF file.
26pub struct GltfScene {
27    /// Loaded meshes with their names.
28    pub meshes: Vec<(String, Mesh3D)>,
29    /// Node hierarchy information for spawning entities.
30    pub nodes: Vec<GltfNode>,
31}
32
33/// Represents a node in the glTF hierarchy.
34#[derive(Clone, Debug)]
35pub struct GltfNode {
36    /// Name of the node (if available).
37    pub name: Option<String>,
38    /// Index of the mesh (if this node has a mesh).
39    pub mesh_index: Option<usize>,
40    /// Local transform: position.
41    pub translation: Vec3,
42    /// Local transform: rotation.
43    pub rotation: Quat,
44    /// Local transform: scale.
45    pub scale: Vec3,
46    /// Child nodes.
47    pub children: Vec<GltfNode>,
48}
49
50impl Default for GltfNode {
51    fn default() -> Self {
52        Self {
53            name: None,
54            mesh_index: None,
55            translation: Vec3::ZERO,
56            rotation: Quat::IDENTITY,
57            scale: Vec3::ONE,
58            children: Vec::new(),
59        }
60    }
61}
62
63/// Loads a glTF file and extracts meshes.
64pub fn load_gltf(
65    device: &Device,
66    queue: &Queue,
67    path: impl AsRef<Path>,
68) -> Result<GltfScene, GltfError> {
69    let path = path.as_ref();
70    let path_str = path.display().to_string();
71
72    // Load the glTF document and buffers
73    let (document, buffers, _images) = gltf::import(path).map_err(|source| GltfError::Load {
74        path: path_str.clone(),
75        source,
76    })?;
77
78    // Extract meshes
79    let mut meshes = Vec::new();
80    for (mesh_idx, mesh) in document.meshes().enumerate() {
81        for (prim_idx, primitive) in mesh.primitives().enumerate() {
82            // Only support triangle mode
83            if primitive.mode() != Mode::Triangles {
84                return Err(GltfError::UnsupportedMode(primitive.mode()));
85            }
86
87            let mesh_name = format!("mesh_{}_prim{}", mesh_idx, prim_idx);
88            let loaded_mesh = load_primitive(device, queue, &primitive, &buffers, &mesh_name)?;
89            meshes.push((mesh_name, loaded_mesh));
90        }
91    }
92
93    // Extract node hierarchy
94    let nodes = extract_nodes(&document, &meshes);
95
96    Ok(GltfScene { meshes, nodes })
97}
98
99/// Extracts the node hierarchy from a glTF document.
100fn extract_nodes(document: &gltf::Document, meshes: &[(String, Mesh3D)]) -> Vec<GltfNode> {
101    let scenes: Vec<_> = document.scenes().collect();
102    let scene = scenes.first();
103
104    match scene {
105        Some(scene) => scene
106            .nodes()
107            .enumerate()
108            .map(|(idx, node)| convert_node(idx, &node, meshes))
109            .collect(),
110        None => Vec::new(),
111    }
112}
113
114/// Converts a glTF node to our GltfNode type.
115fn convert_node(node_idx: usize, node: &gltf::Node, meshes: &[(String, Mesh3D)]) -> GltfNode {
116    let (t, r, s) = node.transform().decomposed();
117
118    // Find mesh index if this node has a mesh
119    let mesh_index = node.mesh().map(|mesh| {
120        // Find the index in our meshes vector
121        let mesh_idx = mesh.index();
122        meshes
123            .iter()
124            .position(|(name, _)| name.starts_with(&format!("mesh_{}_", mesh_idx)))
125            .unwrap_or(0)
126    });
127
128    GltfNode {
129        name: Some(format!("node_{}", node_idx)),
130        mesh_index,
131        translation: Vec3::new(t[0], t[1], t[2]),
132        rotation: Quat::from_xyzw(r[0], r[1], r[2], r[3]),
133        scale: Vec3::new(s[0], s[1], s[2]),
134        children: node
135            .children()
136            .enumerate()
137            .map(|(idx, child)| convert_node(idx, &child, meshes))
138            .collect(),
139    }
140}
141
142/// Loads a single primitive as a Mesh3D.
143fn load_primitive(
144    device: &Device,
145    _queue: &Queue,
146    primitive: &gltf::Primitive,
147    buffers: &[Data],
148    name: &str,
149) -> Result<Mesh3D, GltfError> {
150    let reader = primitive.reader(|buffer| Some(&buffers[buffer.index()]));
151
152    // Read positions (required)
153    let positions: Vec<[f32; 3]> = reader
154        .read_positions()
155        .ok_or(GltfError::MissingPositions)?
156        .collect();
157
158    // Read normals (optional, default to up)
159    let normals: Vec<[f32; 3]> = reader
160        .read_normals()
161        .map(|iter| iter.collect())
162        .unwrap_or_else(|| vec![[0.0, 1.0, 0.0]; positions.len()]);
163
164    // Read UVs (optional, default to 0,0)
165    let uvs: Vec<[f32; 2]> = reader
166        .read_tex_coords(0)
167        .map(|tex| tex.into_f32().collect())
168        .unwrap_or_else(|| vec![[0.0, 0.0]; positions.len()]);
169
170    // Build vertices
171    let vertices: Vec<Vertex3D> = positions
172        .iter()
173        .zip(normals.iter())
174        .zip(uvs.iter())
175        .map(|((&pos, &normal), &uv)| Vertex3D {
176            position: pos,
177            normal,
178            uv,
179        })
180        .collect();
181
182    // Read indices
183    let indices: Vec<u16> = reader
184        .read_indices()
185        .map(|indices| indices.into_u32().map(|i| i as u16).collect())
186        .unwrap_or_else(|| (0..vertices.len() as u16).collect());
187
188    // Create the mesh
189    Ok(Mesh3D::create(device, &vertices, &indices, Some(name)))
190}