1use std::path::Path;
4
5use glam::{Quat, Vec3};
6use gltf::buffer::Data;
7use gltf::mesh::Mode;
8use wgpu::{Device, Queue};
9
10use crate::mesh::Mesh3D;
11use crate::mesh::Vertex3D;
12
13#[derive(thiserror::Error, Debug)]
14pub enum GltfError {
15 #[error("Failed to load glTF file '{path}': {source}")]
16 Load { path: String, source: gltf::Error },
17 #[error("Failed to read glTF buffers: {0}")]
18 Buffer(String),
19 #[error("Primitive mode {0:?} is not supported. Only triangles are supported.")]
20 UnsupportedMode(Mode),
21 #[error("Mesh has no positions")]
22 MissingPositions,
23}
24
25pub struct GltfScene {
27 pub meshes: Vec<(String, Mesh3D)>,
29 pub nodes: Vec<GltfNode>,
31}
32
33#[derive(Clone, Debug)]
35pub struct GltfNode {
36 pub name: Option<String>,
38 pub mesh_index: Option<usize>,
40 pub translation: Vec3,
42 pub rotation: Quat,
44 pub scale: Vec3,
46 pub children: Vec<GltfNode>,
48}
49
50impl Default for GltfNode {
51 fn default() -> Self {
52 Self {
53 name: None,
54 mesh_index: None,
55 translation: Vec3::ZERO,
56 rotation: Quat::IDENTITY,
57 scale: Vec3::ONE,
58 children: Vec::new(),
59 }
60 }
61}
62
63pub fn load_gltf(
65 device: &Device,
66 queue: &Queue,
67 path: impl AsRef<Path>,
68) -> Result<GltfScene, GltfError> {
69 let path = path.as_ref();
70 let path_str = path.display().to_string();
71
72 let (document, buffers, _images) = gltf::import(path).map_err(|source| GltfError::Load {
74 path: path_str.clone(),
75 source,
76 })?;
77
78 let mut meshes = Vec::new();
80 for (mesh_idx, mesh) in document.meshes().enumerate() {
81 for (prim_idx, primitive) in mesh.primitives().enumerate() {
82 if primitive.mode() != Mode::Triangles {
84 return Err(GltfError::UnsupportedMode(primitive.mode()));
85 }
86
87 let mesh_name = format!("mesh_{}_prim{}", mesh_idx, prim_idx);
88 let loaded_mesh = load_primitive(device, queue, &primitive, &buffers, &mesh_name)?;
89 meshes.push((mesh_name, loaded_mesh));
90 }
91 }
92
93 let nodes = extract_nodes(&document, &meshes);
95
96 Ok(GltfScene { meshes, nodes })
97}
98
99fn extract_nodes(document: &gltf::Document, meshes: &[(String, Mesh3D)]) -> Vec<GltfNode> {
101 let scenes: Vec<_> = document.scenes().collect();
102 let scene = scenes.first();
103
104 match scene {
105 Some(scene) => scene
106 .nodes()
107 .enumerate()
108 .map(|(idx, node)| convert_node(idx, &node, meshes))
109 .collect(),
110 None => Vec::new(),
111 }
112}
113
114fn convert_node(node_idx: usize, node: &gltf::Node, meshes: &[(String, Mesh3D)]) -> GltfNode {
116 let (t, r, s) = node.transform().decomposed();
117
118 let mesh_index = node.mesh().map(|mesh| {
120 let mesh_idx = mesh.index();
122 meshes
123 .iter()
124 .position(|(name, _)| name.starts_with(&format!("mesh_{}_", mesh_idx)))
125 .unwrap_or(0)
126 });
127
128 GltfNode {
129 name: Some(format!("node_{}", node_idx)),
130 mesh_index,
131 translation: Vec3::new(t[0], t[1], t[2]),
132 rotation: Quat::from_xyzw(r[0], r[1], r[2], r[3]),
133 scale: Vec3::new(s[0], s[1], s[2]),
134 children: node
135 .children()
136 .enumerate()
137 .map(|(idx, child)| convert_node(idx, &child, meshes))
138 .collect(),
139 }
140}
141
142fn load_primitive(
144 device: &Device,
145 _queue: &Queue,
146 primitive: &gltf::Primitive,
147 buffers: &[Data],
148 name: &str,
149) -> Result<Mesh3D, GltfError> {
150 let reader = primitive.reader(|buffer| Some(&buffers[buffer.index()]));
151
152 let positions: Vec<[f32; 3]> = reader
154 .read_positions()
155 .ok_or(GltfError::MissingPositions)?
156 .collect();
157
158 let normals: Vec<[f32; 3]> = reader
160 .read_normals()
161 .map(|iter| iter.collect())
162 .unwrap_or_else(|| vec![[0.0, 1.0, 0.0]; positions.len()]);
163
164 let uvs: Vec<[f32; 2]> = reader
166 .read_tex_coords(0)
167 .map(|tex| tex.into_f32().collect())
168 .unwrap_or_else(|| vec![[0.0, 0.0]; positions.len()]);
169
170 let vertices: Vec<Vertex3D> = positions
172 .iter()
173 .zip(normals.iter())
174 .zip(uvs.iter())
175 .map(|((&pos, &normal), &uv)| Vertex3D {
176 position: pos,
177 normal,
178 uv,
179 })
180 .collect();
181
182 let indices: Vec<u16> = reader
184 .read_indices()
185 .map(|indices| indices.into_u32().map(|i| i as u16).collect())
186 .unwrap_or_else(|| (0..vertices.len() as u16).collect());
187
188 Ok(Mesh3D::create(device, &vertices, &indices, Some(name)))
190}