Skip to main content

gltforge_unity/
convert.rs

1use error_location::ErrorLocation;
2use gltforge::{
3    parser::resolve_accessor,
4    schema::{
5        AccessorComponentType, AccessorType, Gltf, MaterialAlphaMode, MeshPrimitive,
6        MeshPrimitiveMode,
7    },
8};
9use std::{collections::HashMap, panic::Location};
10
11use crate::{
12    error::{ConvertError, ConvertResult},
13    gltf_primitive_data::GltfPrimitiveData,
14    unity_gltf::UnityGltf,
15    unity_image::UnityImage,
16    unity_indices::UnityIndices,
17    unity_mesh::UnityMesh,
18    unity_node::UnityNode,
19    unity_node_transform::UnityNodeTransform,
20    unity_pbr_metallic_roughness::{
21        ALPHA_MODE_BLEND, ALPHA_MODE_MASK, ALPHA_MODE_OPAQUE, UnityPbrMetallicRoughness,
22    },
23    unity_submesh::UnitySubMesh,
24};
25
26/// Build a [`UnityGltf`] from a parsed glTF document and its loaded buffers.
27///
28/// Converts all nodes and meshes into Unity's left-handed coordinate system.
29/// Each glTF mesh becomes one [`UnityMesh`] with a merged vertex array and one
30/// sub-mesh per glTF primitive, with indices pre-offset into the shared array.
31#[track_caller]
32pub fn build_unity_gltf(
33    gltf: &Gltf,
34    buffers: &[Vec<u8>],
35    file_stem: &str,
36) -> ConvertResult<UnityGltf> {
37    // ---- scene --------------------------------------------------------------
38
39    let scene_idx = gltf.scene.unwrap_or(0) as usize;
40    let scene = gltf.scenes.as_deref().and_then(|s| s.get(scene_idx));
41
42    let scene_name = scene
43        .and_then(|s| s.name.clone())
44        .unwrap_or_else(|| file_stem.to_string());
45    let root_nodes = scene
46        .and_then(|s| s.nodes.as_deref())
47        .unwrap_or(&[])
48        .to_vec();
49
50    // ---- nodes --------------------------------------------------------------
51
52    let mut nodes: HashMap<u32, UnityNode> = HashMap::new();
53
54    for (idx, node) in gltf.nodes.as_deref().unwrap_or(&[]).iter().enumerate() {
55        let children = node.children.as_deref().unwrap_or(&[]).to_vec();
56        let mesh_indices = node.mesh.map(|m| vec![m]).unwrap_or_default();
57
58        nodes.insert(
59            idx as u32,
60            UnityNode {
61                name: node.name.clone().unwrap_or_else(|| idx.to_string()),
62                children,
63                mesh_indices,
64                transform: node_transform(node),
65            },
66        );
67    }
68
69    // ---- images -------------------------------------------------------------
70
71    let images: HashMap<u32, UnityImage> = gltf
72        .images
73        .as_deref()
74        .unwrap_or(&[])
75        .iter()
76        .enumerate()
77        .map(|(idx, img)| {
78            (
79                idx as u32,
80                UnityImage {
81                    name: img.name.clone().unwrap_or_else(|| idx.to_string()),
82                    uri: img.uri.clone(),
83                },
84            )
85        })
86        .collect();
87
88    // ---- PBR materials ------------------------------------------------------
89
90    let schema_textures = gltf.textures.as_deref().unwrap_or(&[]);
91
92    // Resolve a TextureInfo index (texture index) to an image index.
93    let resolve_image =
94        |tex_idx: u32| -> Option<u32> { schema_textures.get(tex_idx as usize)?.source };
95
96    let pbr_metallic_roughness: HashMap<u32, UnityPbrMetallicRoughness> = gltf
97        .materials
98        .as_deref()
99        .unwrap_or(&[])
100        .iter()
101        .enumerate()
102        .map(|(idx, mat)| {
103            let pbr = mat.pbr_metallic_roughness.as_ref();
104
105            let alpha_mode = match mat.alpha_mode {
106                MaterialAlphaMode::Opaque => ALPHA_MODE_OPAQUE,
107                MaterialAlphaMode::Mask => ALPHA_MODE_MASK,
108                MaterialAlphaMode::Blend => ALPHA_MODE_BLEND,
109            };
110
111            (
112                idx as u32,
113                UnityPbrMetallicRoughness {
114                    name: mat.name.clone().unwrap_or_else(|| idx.to_string()),
115
116                    base_color_texture: pbr
117                        .and_then(|p| p.base_color_texture.as_ref())
118                        .and_then(|t| resolve_image(t.index)),
119
120                    metallic_roughness_texture: pbr
121                        .and_then(|p| p.metallic_roughness_texture.as_ref())
122                        .and_then(|t| resolve_image(t.index)),
123
124                    normal_texture: mat
125                        .normal_texture
126                        .as_ref()
127                        .and_then(|t| resolve_image(t.index)),
128
129                    occlusion_texture: mat
130                        .occlusion_texture
131                        .as_ref()
132                        .and_then(|t| resolve_image(t.index)),
133
134                    emissive_texture: mat
135                        .emissive_texture
136                        .as_ref()
137                        .and_then(|t| resolve_image(t.index)),
138
139                    base_color_factor: pbr
140                        .map(|p| p.base_color_factor)
141                        .unwrap_or([1.0, 1.0, 1.0, 1.0]),
142
143                    metallic_factor: pbr.map(|p| p.metallic_factor).unwrap_or(1.0),
144
145                    roughness_factor: pbr.map(|p| p.roughness_factor).unwrap_or(1.0),
146
147                    normal_scale: mat.normal_texture.as_ref().map(|t| t.scale).unwrap_or(1.0),
148
149                    occlusion_strength: mat
150                        .occlusion_texture
151                        .as_ref()
152                        .map(|t| t.strength)
153                        .unwrap_or(1.0),
154
155                    emissive_factor: mat.emissive_factor,
156
157                    alpha_cutoff: mat.alpha_cutoff,
158                    alpha_mode,
159                    double_sided: mat.double_sided,
160                },
161            )
162        })
163        .collect();
164
165    // ---- meshes -------------------------------------------------------------
166
167    let bvs = gltf.buffer_views.as_deref().unwrap_or(&[]);
168    let accessors = gltf.accessors.as_deref().unwrap_or(&[]);
169    let mut meshes: HashMap<u32, UnityMesh> = HashMap::new();
170
171    for (mesh_idx, mesh) in gltf.meshes.as_deref().unwrap_or(&[]).iter().enumerate() {
172        // --- First pass: resolve each primitive's positions and wound indices ---
173
174        let mut prims: Vec<GltfPrimitiveData> = Vec::new();
175
176        for prim in &mesh.primitives {
177            prims.push(resolve_primitive(prim, accessors, bvs, buffers)?);
178        }
179
180        // --- Second pass: merge vertices, offset indices, pick format --------
181
182        let total_verts: usize = prims.iter().map(|p| p.positions.len()).sum();
183        let use_u32 = total_verts > 65535;
184        let all_have_normals = prims.iter().all(|p| p.normals.is_some());
185        let all_have_tangents = prims.iter().all(|p| p.tangents.is_some());
186
187        // Determine how many UV channels are shared by every primitive (stop at first gap).
188        let max_uv_channels = prims.iter().map(|p| p.uvs.len()).min().unwrap_or(0);
189        let uv_channel_count = (0..max_uv_channels)
190            .take_while(|&ch| prims.iter().all(|p| p.uvs[ch].is_some()))
191            .count();
192
193        let mut vertices: Vec<[f32; 3]> = Vec::with_capacity(total_verts);
194        let mut normals: Vec<[f32; 3]> = if all_have_normals {
195            Vec::with_capacity(total_verts)
196        } else {
197            Vec::new()
198        };
199        let mut tangents: Vec<[f32; 4]> = if all_have_tangents {
200            Vec::with_capacity(total_verts)
201        } else {
202            Vec::new()
203        };
204        let mut uvs: Vec<Vec<[f32; 2]>> = (0..uv_channel_count)
205            .map(|_| Vec::with_capacity(total_verts))
206            .collect();
207        let mut sub_meshes: Vec<UnitySubMesh> = Vec::with_capacity(prims.len());
208
209        for (
210            prim_idx,
211            GltfPrimitiveData {
212                positions,
213                normals: prim_norms,
214                tangents: prim_tangs,
215                uvs: prim_uvs,
216                wound,
217            },
218        ) in prims.into_iter().enumerate()
219        {
220            let base = vertices.len() as u32;
221            vertices.extend_from_slice(&positions);
222
223            if let Some(n) = prim_norms {
224                normals.extend_from_slice(&n);
225            }
226
227            if let Some(t) = prim_tangs {
228                tangents.extend_from_slice(&t);
229            }
230
231            for (ch, ch_buf) in uvs.iter_mut().enumerate() {
232                if let Some(Some(ch_uvs)) = prim_uvs.get(ch) {
233                    ch_buf.extend_from_slice(ch_uvs);
234                }
235            }
236
237            let indices = if use_u32 {
238                UnityIndices::U32(wound.into_iter().map(|i| i + base).collect())
239            } else {
240                UnityIndices::U16(wound.into_iter().map(|i| (i + base) as u16).collect())
241            };
242
243            let material_index = mesh.primitives[prim_idx].material;
244
245            sub_meshes.push(UnitySubMesh {
246                indices,
247                material_index,
248            });
249        }
250
251        let name = mesh.name.clone().unwrap_or_else(|| mesh_idx.to_string());
252
253        meshes.insert(
254            mesh_idx as u32,
255            UnityMesh {
256                name,
257                vertices,
258                normals,
259                tangents,
260                uvs,
261                sub_meshes,
262            },
263        );
264    }
265
266    Ok(UnityGltf {
267        scene_name,
268        root_nodes,
269        nodes,
270        meshes,
271        images,
272        pbr_metallic_roughness,
273    })
274}
275
276/// Resolve a single glTF primitive into left-handed positions, optional normals, and wound indices.
277#[track_caller]
278fn resolve_primitive(
279    prim: &MeshPrimitive,
280    accessors: &[gltforge::schema::Accessor],
281    bvs: &[gltforge::schema::BufferView],
282    buffers: &[Vec<u8>],
283) -> ConvertResult<GltfPrimitiveData> {
284    if prim.mode != MeshPrimitiveMode::Triangles {
285        return Err(ConvertError::UnsupportedPrimitiveMode {
286            mode: prim.mode,
287            location: ErrorLocation::from(Location::caller()),
288        });
289    }
290
291    // ---- positions ----------------------------------------------------------
292
293    let pos_id =
294        *prim
295            .attributes
296            .get("POSITION")
297            .ok_or_else(|| ConvertError::NoPositionAttribute {
298                location: ErrorLocation::from(Location::caller()),
299            })? as usize;
300
301    let pos_acc =
302        accessors
303            .get(pos_id)
304            .ok_or_else(|| ConvertError::PositionAccessorOutOfRange {
305                location: ErrorLocation::from(Location::caller()),
306            })?;
307
308    if pos_acc.accessor_type != AccessorType::Vec3
309        || pos_acc.component_type != AccessorComponentType::Float
310    {
311        return Err(ConvertError::InvalidPositionType {
312            location: ErrorLocation::from(Location::caller()),
313        });
314    }
315
316    let pos_bytes = resolve_accessor(pos_acc, bvs, buffers).map_err(|e| ConvertError::Resolve {
317        source: e,
318        location: ErrorLocation::from(Location::caller()),
319    })?;
320
321    let positions: Vec<[f32; 3]> = pos_bytes
322        .chunks_exact(12)
323        .map(|c| {
324            let x = f32::from_le_bytes([c[0], c[1], c[2], c[3]]);
325            let y = f32::from_le_bytes([c[4], c[5], c[6], c[7]]);
326            let z = f32::from_le_bytes([c[8], c[9], c[10], c[11]]);
327            [-x, y, z]
328        })
329        .collect();
330
331    // ---- indices ------------------------------------------------------------
332
333    let idx_id = prim.indices.ok_or_else(|| ConvertError::NoIndices {
334        location: ErrorLocation::from(Location::caller()),
335    })?;
336
337    let idx_acc =
338        accessors
339            .get(idx_id as usize)
340            .ok_or_else(|| ConvertError::IndexAccessorOutOfRange {
341                location: ErrorLocation::from(Location::caller()),
342            })?;
343
344    let idx_bytes = resolve_accessor(idx_acc, bvs, buffers).map_err(|e| ConvertError::Resolve {
345        source: e,
346        location: ErrorLocation::from(Location::caller()),
347    })?;
348
349    let raw = decode_indices(idx_bytes, idx_acc.component_type)?;
350
351    // Reverse winding order (glTF right-handed → Unity left-handed).
352    let wound: Vec<u32> = raw
353        .chunks_exact(3)
354        .flat_map(|tri| [tri[0], tri[2], tri[1]])
355        .collect();
356
357    // ---- normals (optional) -------------------------------------------------
358
359    let normals = if let Some(&norm_id) = prim.attributes.get("NORMAL") {
360        let norm_acc = accessors.get(norm_id as usize).ok_or_else(|| {
361            ConvertError::PositionAccessorOutOfRange {
362                location: ErrorLocation::from(Location::caller()),
363            }
364        })?;
365
366        let norm_bytes =
367            resolve_accessor(norm_acc, bvs, buffers).map_err(|e| ConvertError::Resolve {
368                source: e,
369                location: ErrorLocation::from(Location::caller()),
370            })?;
371
372        Some(
373            norm_bytes
374                .chunks_exact(12)
375                .map(|c| {
376                    let x = f32::from_le_bytes([c[0], c[1], c[2], c[3]]);
377                    let y = f32::from_le_bytes([c[4], c[5], c[6], c[7]]);
378                    let z = f32::from_le_bytes([c[8], c[9], c[10], c[11]]);
379                    [-x, y, z]
380                })
381                .collect(),
382        )
383    } else {
384        None
385    };
386
387    // ---- tangents (optional) ------------------------------------------------
388
389    let tangents = if let Some(&tang_id) = prim.attributes.get("TANGENT") {
390        let tang_acc = accessors.get(tang_id as usize).ok_or_else(|| {
391            ConvertError::PositionAccessorOutOfRange {
392                location: ErrorLocation::from(Location::caller()),
393            }
394        })?;
395
396        let tang_bytes =
397            resolve_accessor(tang_acc, bvs, buffers).map_err(|e| ConvertError::Resolve {
398                source: e,
399                location: ErrorLocation::from(Location::caller()),
400            })?;
401
402        Some(
403            tang_bytes
404                .chunks_exact(16)
405                .map(|c| {
406                    let x = f32::from_le_bytes([c[0], c[1], c[2], c[3]]);
407                    let y = f32::from_le_bytes([c[4], c[5], c[6], c[7]]);
408                    let z = f32::from_le_bytes([c[8], c[9], c[10], c[11]]);
409                    let w = f32::from_le_bytes([c[12], c[13], c[14], c[15]]);
410                    // Negate X (coordinate flip) and W (bitangent handedness flip).
411                    [-x, y, z, -w]
412                })
413                .collect(),
414        )
415    } else {
416        None
417    };
418
419    // ---- UV channels (optional, TEXCOORD_0 … TEXCOORD_7) -------------------
420
421    let mut uvs: Vec<Option<Vec<[f32; 2]>>> = Vec::new();
422    for ch in 0u32..8 {
423        let key = format!("TEXCOORD_{ch}");
424        let Some(&uv_id) = prim.attributes.get(&key) else {
425            break; // Stop at the first absent channel.
426        };
427
428        let uv_acc = accessors.get(uv_id as usize).ok_or_else(|| {
429            ConvertError::PositionAccessorOutOfRange {
430                location: ErrorLocation::from(Location::caller()),
431            }
432        })?;
433
434        let uv_bytes =
435            resolve_accessor(uv_acc, bvs, buffers).map_err(|e| ConvertError::Resolve {
436                source: e,
437                location: ErrorLocation::from(Location::caller()),
438            })?;
439
440        let channel: Vec<[f32; 2]> = uv_bytes
441            .chunks_exact(8)
442            .map(|c| {
443                let u = f32::from_le_bytes([c[0], c[1], c[2], c[3]]);
444                let v = f32::from_le_bytes([c[4], c[5], c[6], c[7]]);
445                // Flip V: glTF origin is top-left, Unity origin is bottom-left.
446                [u, 1.0 - v]
447            })
448            .collect();
449
450        uvs.push(Some(channel));
451    }
452
453    Ok(GltfPrimitiveData {
454        positions,
455        normals,
456        tangents,
457        uvs,
458        wound,
459    })
460}
461
462/// Build a [`UnityNodeTransform`] from a glTF node, converting to Unity's left-handed coordinate system.
463///
464/// Handles both the `matrix` form (column-major 4×4, decomposed into TRS) and the
465/// separate `translation`/`rotation`/`scale` properties. Missing components default to identity.
466fn node_transform(node: &gltforge::schema::Node) -> UnityNodeTransform {
467    if let Some(m) = &node.matrix {
468        mat4_to_node_transform(m)
469    } else {
470        let position = node
471            .translation
472            .map(|t| [-t[0], t[1], t[2]])
473            .unwrap_or([0.0, 0.0, 0.0]);
474        let rotation = node
475            .rotation
476            .map(|r| [-r[0], r[1], r[2], -r[3]])
477            .unwrap_or([0.0, 0.0, 0.0, 1.0]);
478        let scale = node.scale.unwrap_or([1.0, 1.0, 1.0]);
479        UnityNodeTransform {
480            position,
481            rotation,
482            scale,
483        }
484    }
485}
486
487/// Decompose a glTF column-major 4×4 matrix into TRS and convert to Unity left-handed space.
488fn mat4_to_node_transform(m: &[f32; 16]) -> UnityNodeTransform {
489    // glTF matrix is column-major: column k starts at index k*4.
490    // Translation is the last column.
491    let tx = m[12];
492    let ty = m[13];
493    let tz = m[14];
494
495    // Scale = length of each basis column (columns 0, 1, 2).
496    let sx = (m[0] * m[0] + m[1] * m[1] + m[2] * m[2]).sqrt();
497    let sy = (m[4] * m[4] + m[5] * m[5] + m[6] * m[6]).sqrt();
498    let sz = (m[8] * m[8] + m[9] * m[9] + m[10] * m[10]).sqrt();
499
500    // Rotation matrix: normalize each basis column.
501    // Row-major indexing: r[row][col] = m[col*4 + row] / scale[col]
502    let r00 = m[0] / sx;
503    let r10 = m[1] / sx;
504    let r20 = m[2] / sx;
505    let r01 = m[4] / sy;
506    let r11 = m[5] / sy;
507    let r21 = m[6] / sy;
508    let r02 = m[8] / sz;
509    let r12 = m[9] / sz;
510    let r22 = m[10] / sz;
511
512    let [qx, qy, qz, qw] = rot_mat_to_quat([r00, r01, r02, r10, r11, r12, r20, r21, r22]);
513
514    UnityNodeTransform {
515        position: [-tx, ty, tz],
516        rotation: [-qx, qy, qz, -qw],
517        scale: [sx, sy, sz],
518    }
519}
520
/// Turn a 3×3 rotation matrix into a quaternion laid out as `[x, y, z, w]`.
///
/// The nine inputs are the matrix entries in row-major order:
/// `[r00,r01,r02, r10,r11,r12, r20,r21,r22]`.
///
/// Branches on the trace and then on the largest diagonal entry (Shepperd's method) so the
/// divisor `s` is always derived from the dominant quaternion component.
fn rot_mat_to_quat([r00, r01, r02, r10, r11, r12, r20, r21, r22]: [f32; 9]) -> [f32; 4] {
    let trace = r00 + r11 + r22;

    // Positive trace: w dominates, s = 4w.
    if trace > 0.0 {
        let s = (trace + 1.0).sqrt() * 2.0;
        return [(r21 - r12) / s, (r02 - r20) / s, (r10 - r01) / s, 0.25 * s];
    }

    // r00 is the largest diagonal entry: x dominates, s = 4x.
    if r00 > r11 && r00 > r22 {
        let s = (1.0 + r00 - r11 - r22).sqrt() * 2.0;
        return [0.25 * s, (r01 + r10) / s, (r02 + r20) / s, (r21 - r12) / s];
    }

    // r11 is the largest diagonal entry: y dominates, s = 4y.
    if r11 > r22 {
        let s = (1.0 + r11 - r00 - r22).sqrt() * 2.0;
        return [(r01 + r10) / s, 0.25 * s, (r12 + r21) / s, (r02 - r20) / s];
    }

    // Otherwise r22 is the largest: z dominates, s = 4z.
    let s = (1.0 + r22 - r00 - r11).sqrt() * 2.0;
    [(r02 + r20) / s, (r12 + r21) / s, 0.25 * s, (r10 - r01) / s]
}
558
559/// Decode raw index bytes into a flat `Vec<u32>` regardless of source format.
560#[track_caller]
561fn decode_indices(bytes: &[u8], component_type: AccessorComponentType) -> ConvertResult<Vec<u32>> {
562    match component_type {
563        AccessorComponentType::UnsignedByte => Ok(bytes.iter().map(|&b| b as u32).collect()),
564        AccessorComponentType::UnsignedShort => Ok(bytes
565            .chunks_exact(2)
566            .map(|c| u16::from_le_bytes([c[0], c[1]]) as u32)
567            .collect()),
568        AccessorComponentType::UnsignedInt => Ok(bytes
569            .chunks_exact(4)
570            .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
571            .collect()),
572        other => Err(ConvertError::UnsupportedIndexComponentType {
573            component_type: other,
574            location: ErrorLocation::from(Location::caller()),
575        }),
576    }
577}