// gltforge_unity/convert.rs

1use crate::{
2    error::{ConvertError, ConvertResult},
3    gltf_primitive_data::GltfPrimitiveData,
4    unity_gltf::UnityGltf,
5};
6
7use gltforge::{
8    parser::resolve_accessor,
9    schema::{
10        AccessorComponentType, AccessorType, Gltf, MaterialAlphaMode, MeshPrimitive,
11        MeshPrimitiveMode,
12    },
13};
14use gltforge_unity_core::{
15    ALPHA_MODE_BLEND, ALPHA_MODE_MASK, ALPHA_MODE_OPAQUE, UnityGameObject, UnityImage,
16    UnityIndices, UnityMesh, UnityPbrMetallicRoughness, UnitySubMesh, UnityTransform,
17};
18
19use error_location::ErrorLocation;
20use std::{collections::HashMap, panic::Location};
21
22/// Build a [`UnityGltf`] from a parsed glTF document and its loaded buffers.
23///
24/// Converts all nodes and meshes into Unity's left-handed coordinate system.
25/// Each glTF mesh becomes one [`UnityMesh`] with a merged vertex array and one
26/// sub-mesh per glTF primitive, with indices pre-offset into the shared array.
27#[track_caller]
28pub fn build_unity_gltf(
29    gltf: &Gltf,
30    buffers: &[Vec<u8>],
31    file_stem: &str,
32) -> ConvertResult<UnityGltf> {
33    // ---- scene --------------------------------------------------------------
34
35    let scene_idx = gltf.scene.unwrap_or(0) as usize;
36    let scene = gltf.scenes.as_deref().and_then(|s| s.get(scene_idx));
37
38    let scene_name = scene
39        .and_then(|s| s.name.clone())
40        .unwrap_or_else(|| file_stem.to_string());
41    let root_game_objects = scene
42        .and_then(|s| s.nodes.as_deref())
43        .unwrap_or(&[])
44        .to_vec();
45
46    // ---- game objects -------------------------------------------------------
47
48    let mut game_objects: HashMap<u32, UnityGameObject> = HashMap::new();
49
50    for (idx, node) in gltf.nodes.as_deref().unwrap_or(&[]).iter().enumerate() {
51        let children = node.children.as_deref().unwrap_or(&[]).to_vec();
52        let mesh_indices = node.mesh.map(|m| vec![m]).unwrap_or_default();
53
54        game_objects.insert(
55            idx as u32,
56            UnityGameObject {
57                name: node.name.clone().unwrap_or_else(|| idx.to_string()),
58                children,
59                mesh_indices,
60                transform: build_transform(node),
61            },
62        );
63    }
64
65    // ---- images -------------------------------------------------------------
66
67    let bvs = gltf.buffer_views.as_deref().unwrap_or(&[]);
68
69    let images: HashMap<u32, UnityImage> = gltf
70        .images
71        .as_deref()
72        .unwrap_or(&[])
73        .iter()
74        .enumerate()
75        .map(|(idx, img)| {
76            let bytes = img.buffer_view.and_then(|bv_idx| {
77                let bv = bvs.get(bv_idx as usize)?;
78                let buf = buffers.get(bv.buffer as usize)?;
79                let start = bv.byte_offset as usize;
80                let end = start + bv.byte_length as usize;
81                buf.get(start..end).map(|s| s.to_vec())
82            });
83            (
84                idx as u32,
85                UnityImage {
86                    name: img.name.clone().unwrap_or_else(|| idx.to_string()),
87                    uri: img.uri.clone(),
88                    bytes,
89                },
90            )
91        })
92        .collect();
93
94    // ---- PBR materials ------------------------------------------------------
95
96    let schema_textures = gltf.textures.as_deref().unwrap_or(&[]);
97
98    // Resolve a TextureInfo index (texture index) to an image index.
99    let resolve_image =
100        |tex_idx: u32| -> Option<u32> { schema_textures.get(tex_idx as usize)?.source };
101
102    let pbr_metallic_roughness: HashMap<u32, UnityPbrMetallicRoughness> = gltf
103        .materials
104        .as_deref()
105        .unwrap_or(&[])
106        .iter()
107        .enumerate()
108        .map(|(idx, mat)| {
109            let pbr = mat.pbr_metallic_roughness.as_ref();
110
111            let alpha_mode = match mat.alpha_mode {
112                MaterialAlphaMode::Opaque => ALPHA_MODE_OPAQUE,
113                MaterialAlphaMode::Mask => ALPHA_MODE_MASK,
114                MaterialAlphaMode::Blend => ALPHA_MODE_BLEND,
115            };
116
117            (
118                idx as u32,
119                UnityPbrMetallicRoughness {
120                    name: mat.name.clone().unwrap_or_else(|| idx.to_string()),
121
122                    base_color_texture: pbr
123                        .and_then(|p| p.base_color_texture.as_ref())
124                        .and_then(|t| resolve_image(t.index)),
125
126                    metallic_roughness_texture: pbr
127                        .and_then(|p| p.metallic_roughness_texture.as_ref())
128                        .and_then(|t| resolve_image(t.index)),
129
130                    normal_texture: mat
131                        .normal_texture
132                        .as_ref()
133                        .and_then(|t| resolve_image(t.index)),
134
135                    occlusion_texture: mat
136                        .occlusion_texture
137                        .as_ref()
138                        .and_then(|t| resolve_image(t.index)),
139
140                    emissive_texture: mat
141                        .emissive_texture
142                        .as_ref()
143                        .and_then(|t| resolve_image(t.index)),
144
145                    base_color_factor: pbr
146                        .map(|p| p.base_color_factor)
147                        .unwrap_or([1.0, 1.0, 1.0, 1.0]),
148
149                    metallic_factor: pbr.map(|p| p.metallic_factor).unwrap_or(1.0),
150
151                    roughness_factor: pbr.map(|p| p.roughness_factor).unwrap_or(1.0),
152
153                    normal_scale: mat.normal_texture.as_ref().map(|t| t.scale).unwrap_or(1.0),
154
155                    occlusion_strength: mat
156                        .occlusion_texture
157                        .as_ref()
158                        .map(|t| t.strength)
159                        .unwrap_or(1.0),
160
161                    emissive_factor: mat.emissive_factor,
162
163                    alpha_cutoff: mat.alpha_cutoff,
164                    alpha_mode,
165                    double_sided: mat.double_sided,
166                },
167            )
168        })
169        .collect();
170
171    // ---- meshes -------------------------------------------------------------
172
173    let accessors = gltf.accessors.as_deref().unwrap_or(&[]);
174    let mut meshes: HashMap<u32, UnityMesh> = HashMap::new();
175
176    for (mesh_idx, mesh) in gltf.meshes.as_deref().unwrap_or(&[]).iter().enumerate() {
177        // --- First pass: resolve each primitive's positions and wound indices ---
178
179        let mut prims: Vec<GltfPrimitiveData> = Vec::new();
180
181        for prim in &mesh.primitives {
182            prims.push(resolve_primitive(prim, accessors, bvs, buffers)?);
183        }
184
185        // --- Second pass: merge vertices, offset indices, pick format --------
186
187        let total_verts: usize = prims.iter().map(|p| p.positions.len()).sum();
188        let use_u32 = total_verts > 65535;
189        let all_have_normals = prims.iter().all(|p| p.normals.is_some());
190        let all_have_tangents = prims.iter().all(|p| p.tangents.is_some());
191
192        // Determine how many UV channels are shared by every primitive (stop at first gap).
193        let max_uv_channels = prims.iter().map(|p| p.uvs.len()).min().unwrap_or(0);
194        let uv_channel_count = (0..max_uv_channels)
195            .take_while(|&ch| prims.iter().all(|p| p.uvs[ch].is_some()))
196            .count();
197
198        let mut vertices: Vec<[f32; 3]> = Vec::with_capacity(total_verts);
199        let mut normals: Vec<[f32; 3]> = if all_have_normals {
200            Vec::with_capacity(total_verts)
201        } else {
202            Vec::new()
203        };
204        let mut tangents: Vec<[f32; 4]> = if all_have_tangents {
205            Vec::with_capacity(total_verts)
206        } else {
207            Vec::new()
208        };
209        let mut uvs: Vec<Vec<[f32; 2]>> = (0..uv_channel_count)
210            .map(|_| Vec::with_capacity(total_verts))
211            .collect();
212        let mut sub_meshes: Vec<UnitySubMesh> = Vec::with_capacity(prims.len());
213
214        for (
215            prim_idx,
216            GltfPrimitiveData {
217                positions,
218                normals: prim_norms,
219                tangents: prim_tangs,
220                uvs: prim_uvs,
221                wound,
222            },
223        ) in prims.into_iter().enumerate()
224        {
225            let base = vertices.len() as u32;
226            vertices.extend_from_slice(&positions);
227
228            if let Some(n) = prim_norms {
229                normals.extend_from_slice(&n);
230            }
231
232            if let Some(t) = prim_tangs {
233                tangents.extend_from_slice(&t);
234            }
235
236            for (ch, ch_buf) in uvs.iter_mut().enumerate() {
237                if let Some(Some(ch_uvs)) = prim_uvs.get(ch) {
238                    ch_buf.extend_from_slice(ch_uvs);
239                }
240            }
241
242            let indices = if use_u32 {
243                UnityIndices::U32(wound.into_iter().map(|i| i + base).collect())
244            } else {
245                UnityIndices::U16(wound.into_iter().map(|i| (i + base) as u16).collect())
246            };
247
248            let material_index = mesh.primitives[prim_idx].material;
249
250            sub_meshes.push(UnitySubMesh {
251                indices,
252                material_index,
253            });
254        }
255
256        let name = mesh.name.clone().unwrap_or_else(|| mesh_idx.to_string());
257
258        meshes.insert(
259            mesh_idx as u32,
260            UnityMesh {
261                name,
262                vertices,
263                normals,
264                tangents,
265                uvs,
266                sub_meshes,
267            },
268        );
269    }
270
271    Ok(UnityGltf {
272        scene_name,
273        root_game_objects,
274        game_objects,
275        meshes,
276        images,
277        pbr_metallic_roughness,
278    })
279}
280
281/// Resolve a single glTF primitive into left-handed positions, optional normals, and wound indices.
282#[track_caller]
283fn resolve_primitive(
284    prim: &MeshPrimitive,
285    accessors: &[gltforge::schema::Accessor],
286    bvs: &[gltforge::schema::BufferView],
287    buffers: &[Vec<u8>],
288) -> ConvertResult<GltfPrimitiveData> {
289    if prim.mode != MeshPrimitiveMode::Triangles {
290        return Err(ConvertError::UnsupportedPrimitiveMode {
291            mode: prim.mode,
292            location: ErrorLocation::from(Location::caller()),
293        });
294    }
295
296    // ---- positions ----------------------------------------------------------
297
298    let pos_id =
299        *prim
300            .attributes
301            .get("POSITION")
302            .ok_or_else(|| ConvertError::NoPositionAttribute {
303                location: ErrorLocation::from(Location::caller()),
304            })? as usize;
305
306    let pos_acc =
307        accessors
308            .get(pos_id)
309            .ok_or_else(|| ConvertError::PositionAccessorOutOfRange {
310                location: ErrorLocation::from(Location::caller()),
311            })?;
312
313    if pos_acc.accessor_type != AccessorType::Vec3
314        || pos_acc.component_type != AccessorComponentType::Float
315    {
316        return Err(ConvertError::InvalidPositionType {
317            location: ErrorLocation::from(Location::caller()),
318        });
319    }
320
321    let pos_bytes = resolve_accessor(pos_acc, bvs, buffers).map_err(|e| ConvertError::Resolve {
322        source: e,
323        location: ErrorLocation::from(Location::caller()),
324    })?;
325
326    let positions: Vec<[f32; 3]> = pos_bytes
327        .chunks_exact(12)
328        .map(|c| {
329            let x = f32::from_le_bytes([c[0], c[1], c[2], c[3]]);
330            let y = f32::from_le_bytes([c[4], c[5], c[6], c[7]]);
331            let z = f32::from_le_bytes([c[8], c[9], c[10], c[11]]);
332            [-x, y, z]
333        })
334        .collect();
335
336    // ---- indices ------------------------------------------------------------
337
338    let idx_id = prim.indices.ok_or_else(|| ConvertError::NoIndices {
339        location: ErrorLocation::from(Location::caller()),
340    })?;
341
342    let idx_acc =
343        accessors
344            .get(idx_id as usize)
345            .ok_or_else(|| ConvertError::IndexAccessorOutOfRange {
346                location: ErrorLocation::from(Location::caller()),
347            })?;
348
349    let idx_bytes = resolve_accessor(idx_acc, bvs, buffers).map_err(|e| ConvertError::Resolve {
350        source: e,
351        location: ErrorLocation::from(Location::caller()),
352    })?;
353
354    let raw = decode_indices(idx_bytes, idx_acc.component_type)?;
355
356    // Reverse winding order (glTF right-handed → Unity left-handed).
357    let wound: Vec<u32> = raw
358        .chunks_exact(3)
359        .flat_map(|tri| [tri[0], tri[2], tri[1]])
360        .collect();
361
362    // ---- normals (optional) -------------------------------------------------
363
364    let normals = if let Some(&norm_id) = prim.attributes.get("NORMAL") {
365        let norm_acc = accessors.get(norm_id as usize).ok_or_else(|| {
366            ConvertError::PositionAccessorOutOfRange {
367                location: ErrorLocation::from(Location::caller()),
368            }
369        })?;
370
371        let norm_bytes =
372            resolve_accessor(norm_acc, bvs, buffers).map_err(|e| ConvertError::Resolve {
373                source: e,
374                location: ErrorLocation::from(Location::caller()),
375            })?;
376
377        Some(
378            norm_bytes
379                .chunks_exact(12)
380                .map(|c| {
381                    let x = f32::from_le_bytes([c[0], c[1], c[2], c[3]]);
382                    let y = f32::from_le_bytes([c[4], c[5], c[6], c[7]]);
383                    let z = f32::from_le_bytes([c[8], c[9], c[10], c[11]]);
384                    [-x, y, z]
385                })
386                .collect(),
387        )
388    } else {
389        None
390    };
391
392    // ---- tangents (optional) ------------------------------------------------
393
394    let tangents = if let Some(&tang_id) = prim.attributes.get("TANGENT") {
395        let tang_acc = accessors.get(tang_id as usize).ok_or_else(|| {
396            ConvertError::PositionAccessorOutOfRange {
397                location: ErrorLocation::from(Location::caller()),
398            }
399        })?;
400
401        let tang_bytes =
402            resolve_accessor(tang_acc, bvs, buffers).map_err(|e| ConvertError::Resolve {
403                source: e,
404                location: ErrorLocation::from(Location::caller()),
405            })?;
406
407        Some(
408            tang_bytes
409                .chunks_exact(16)
410                .map(|c| {
411                    let x = f32::from_le_bytes([c[0], c[1], c[2], c[3]]);
412                    let y = f32::from_le_bytes([c[4], c[5], c[6], c[7]]);
413                    let z = f32::from_le_bytes([c[8], c[9], c[10], c[11]]);
414                    let w = f32::from_le_bytes([c[12], c[13], c[14], c[15]]);
415                    // Negate X (coordinate flip) and W (bitangent handedness flip).
416                    [-x, y, z, -w]
417                })
418                .collect(),
419        )
420    } else {
421        None
422    };
423
424    // ---- UV channels (optional, TEXCOORD_0 … TEXCOORD_7) -------------------
425
426    let mut uvs: Vec<Option<Vec<[f32; 2]>>> = Vec::new();
427    for ch in 0u32..8 {
428        let key = format!("TEXCOORD_{ch}");
429        let Some(&uv_id) = prim.attributes.get(&key) else {
430            break; // Stop at the first absent channel.
431        };
432
433        let uv_acc = accessors.get(uv_id as usize).ok_or_else(|| {
434            ConvertError::PositionAccessorOutOfRange {
435                location: ErrorLocation::from(Location::caller()),
436            }
437        })?;
438
439        let uv_bytes =
440            resolve_accessor(uv_acc, bvs, buffers).map_err(|e| ConvertError::Resolve {
441                source: e,
442                location: ErrorLocation::from(Location::caller()),
443            })?;
444
445        let channel: Vec<[f32; 2]> = uv_bytes
446            .chunks_exact(8)
447            .map(|c| {
448                let u = f32::from_le_bytes([c[0], c[1], c[2], c[3]]);
449                let v = f32::from_le_bytes([c[4], c[5], c[6], c[7]]);
450                // Flip V: glTF origin is top-left, Unity origin is bottom-left.
451                [u, 1.0 - v]
452            })
453            .collect();
454
455        uvs.push(Some(channel));
456    }
457
458    Ok(GltfPrimitiveData {
459        positions,
460        normals,
461        tangents,
462        uvs,
463        wound,
464    })
465}
466
467/// Build a [`UnityTransform`] from a glTF node, converting to Unity's left-handed coordinate system.
468///
469/// Handles both the `matrix` form (column-major 4×4, decomposed into TRS) and the
470/// separate `translation`/`rotation`/`scale` properties. Missing components default to identity.
471fn build_transform(node: &gltforge::schema::Node) -> UnityTransform {
472    if let Some(m) = &node.matrix {
473        mat4_to_transform(m)
474    } else {
475        let position = node
476            .translation
477            .map(|t| [-t[0], t[1], t[2]])
478            .unwrap_or([0.0, 0.0, 0.0]);
479        let rotation = node
480            .rotation
481            .map(|r| [-r[0], r[1], r[2], -r[3]])
482            .unwrap_or([0.0, 0.0, 0.0, 1.0]);
483        let scale = node.scale.unwrap_or([1.0, 1.0, 1.0]);
484        UnityTransform {
485            position,
486            rotation,
487            scale,
488        }
489    }
490}
491
492/// Decompose a glTF column-major 4×4 matrix into TRS and convert to Unity left-handed space.
493fn mat4_to_transform(m: &[f32; 16]) -> UnityTransform {
494    // glTF matrix is column-major: column k starts at index k*4.
495    // Translation is the last column.
496    let tx = m[12];
497    let ty = m[13];
498    let tz = m[14];
499
500    // Scale = length of each basis column (columns 0, 1, 2).
501    let sx = (m[0] * m[0] + m[1] * m[1] + m[2] * m[2]).sqrt();
502    let sy = (m[4] * m[4] + m[5] * m[5] + m[6] * m[6]).sqrt();
503    let sz = (m[8] * m[8] + m[9] * m[9] + m[10] * m[10]).sqrt();
504
505    // Rotation matrix: normalize each basis column.
506    // Row-major indexing: r[row][col] = m[col*4 + row] / scale[col]
507    let r00 = m[0] / sx;
508    let r10 = m[1] / sx;
509    let r20 = m[2] / sx;
510    let r01 = m[4] / sy;
511    let r11 = m[5] / sy;
512    let r21 = m[6] / sy;
513    let r02 = m[8] / sz;
514    let r12 = m[9] / sz;
515    let r22 = m[10] / sz;
516
517    let [qx, qy, qz, qw] = rot_mat_to_quat([r00, r01, r02, r10, r11, r12, r20, r21, r22]);
518
519    UnityTransform {
520        position: [-tx, ty, tz],
521        rotation: [-qx, qy, qz, -qw],
522        scale: [sx, sy, sz],
523    }
524}
525
/// Convert a 3×3 rotation matrix (row-major, packed as `[r00,r01,r02, r10,r11,r12, r20,r21,r22]`)
/// to a unit quaternion `[x, y, z, w]`.
///
/// Uses Shepperd's method for numerical stability: it branches on whichever of the trace
/// and the diagonal terms is largest.
///
/// The result is then renormalized, so slightly non-orthonormal input (e.g. float
/// round-off in a stored matrix) still yields a unit quaternion. If the raw result has
/// zero length or is not finite (degenerate input), the identity quaternion is returned.
fn rot_mat_to_quat([r00, r01, r02, r10, r11, r12, r20, r21, r22]: [f32; 9]) -> [f32; 4] {
    let trace = r00 + r11 + r22;

    let q = if trace > 0.0 {
        let s = (trace + 1.0).sqrt() * 2.0; // s = 4w
        let w = 0.25 * s;
        let x = (r21 - r12) / s;
        let y = (r02 - r20) / s;
        let z = (r10 - r01) / s;
        [x, y, z, w]
    } else if r00 > r11 && r00 > r22 {
        let s = (1.0 + r00 - r11 - r22).sqrt() * 2.0; // s = 4x
        let w = (r21 - r12) / s;
        let x = 0.25 * s;
        let y = (r01 + r10) / s;
        let z = (r02 + r20) / s;
        [x, y, z, w]
    } else if r11 > r22 {
        let s = (1.0 + r11 - r00 - r22).sqrt() * 2.0; // s = 4y
        let w = (r02 - r20) / s;
        let x = (r01 + r10) / s;
        let y = 0.25 * s;
        let z = (r12 + r21) / s;
        [x, y, z, w]
    } else {
        let s = (1.0 + r22 - r00 - r11).sqrt() * 2.0; // s = 4z
        let w = (r10 - r01) / s;
        let x = (r02 + r20) / s;
        let y = (r12 + r21) / s;
        let z = 0.25 * s;
        [x, y, z, w]
    };

    let len = q.iter().map(|c| c * c).sum::<f32>().sqrt();
    if len > 0.0 && len.is_finite() {
        q.map(|c| c / len)
    } else {
        [0.0, 0.0, 0.0, 1.0]
    }
}
563
564/// Decode raw index bytes into a flat `Vec<u32>` regardless of source format.
565#[track_caller]
566fn decode_indices(bytes: &[u8], component_type: AccessorComponentType) -> ConvertResult<Vec<u32>> {
567    match component_type {
568        AccessorComponentType::UnsignedByte => Ok(bytes.iter().map(|&b| b as u32).collect()),
569        AccessorComponentType::UnsignedShort => Ok(bytes
570            .chunks_exact(2)
571            .map(|c| u16::from_le_bytes([c[0], c[1]]) as u32)
572            .collect()),
573        AccessorComponentType::UnsignedInt => Ok(bytes
574            .chunks_exact(4)
575            .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
576            .collect()),
577        other => Err(ConvertError::UnsupportedIndexComponentType {
578            component_type: other,
579            location: ErrorLocation::from(Location::caller()),
580        }),
581    }
582}