Skip to main content

oxihuman_export/
scene_graph.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::io::Write;
5use std::path::Path;
6
7use anyhow::Result;
8use bytemuck::cast_slice;
9use oxihuman_mesh::MeshBuffers;
10use serde_json::json;
11
// GLB container constants (same as glb.rs / scene.rs).
// Each four-byte ASCII tag is stored as the u32 obtained by reading it little-endian,
// so writing the constant back with `to_le_bytes` reproduces the tag.
const GLB_MAGIC: u32 = u32::from_le_bytes(*b"glTF");
const GLB_VERSION: u32 = 2;
const CHUNK_JSON: u32 = u32::from_le_bytes(*b"JSON");
const CHUNK_BIN: u32 = u32::from_le_bytes(*b"BIN\0");
17
18// ── Transform ────────────────────────────────────────────────────────────────
19
/// A 4×4 transform matrix stored as 16 floats in column-major order.
///
/// The element at `(row, col)` lives at index `col * 4 + row`, so the
/// translation part occupies indices `12..=14`. [`Default`] yields the identity.
#[derive(Debug, Clone, PartialEq)]
pub struct Transform {
    pub matrix: [f32; 16],
}

impl Default for Transform {
    /// The default transform is the identity, identical to [`Transform::identity`].
    fn default() -> Self {
        Self::identity()
    }
}

impl Transform {
    /// Return the 4×4 identity matrix.
    pub fn identity() -> Self {
        #[rustfmt::skip]
        let m = [
            1.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0,
            0.0, 0.0, 0.0, 1.0,
        ];
        Self { matrix: m }
    }

    /// Translation transform: moves by `(x, y, z)`.
    ///
    /// Column-major layout means the translation lives in column 3
    /// (indices `12..=14`).
    pub fn translation(x: f32, y: f32, z: f32) -> Self {
        #[rustfmt::skip]
        let m = [
            1.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0,
              x,   y,   z, 1.0,
        ];
        Self { matrix: m }
    }

    /// Per-axis scale transform: scales X by `sx`, Y by `sy` and Z by `sz`.
    ///
    /// Pass the same value three times for a uniform scale.
    pub fn scale(sx: f32, sy: f32, sz: f32) -> Self {
        #[rustfmt::skip]
        let m = [
             sx, 0.0, 0.0, 0.0,
            0.0,  sy, 0.0, 0.0,
            0.0, 0.0,  sz, 0.0,
            0.0, 0.0, 0.0, 1.0,
        ];
        Self { matrix: m }
    }

    /// Matrix multiply: `self * other` (column-major 4×4).
    ///
    /// Applied to a column vector, the result transforms by `other` first and
    /// by `self` second.
    pub fn compose(&self, other: &Transform) -> Transform {
        let a = &self.matrix;
        let b = &other.matrix;
        let mut c = [0.0f32; 16];
        // c[col*4 + row] = sum_k a[k*4 + row] * b[col*4 + k]
        for col in 0..4usize {
            for row in 0..4usize {
                c[col * 4 + row] = (0..4usize)
                    .fold(0.0f32, |acc, k| acc + a[k * 4 + row] * b[col * 4 + k]);
            }
        }
        Transform { matrix: c }
    }

    /// Returns true iff every element is within `1e-6` of the identity matrix.
    ///
    /// A matrix containing NaN is never reported as identity, because the
    /// comparison against the tolerance is false for NaN.
    pub fn is_identity(&self) -> bool {
        let id = Self::identity();
        self.matrix
            .iter()
            .zip(id.matrix.iter())
            .all(|(a, b)| (a - b).abs() < 1e-6)
    }
}
91
92// ── SceneNode ─────────────────────────────────────────────────────────────────
93
94/// A node in the scene graph.
95pub struct SceneNode {
96    pub name: String,
97    pub transform: Transform,
98    pub mesh: Option<MeshBuffers>,
99    pub children: Vec<SceneNode>,
100}
101
102impl SceneNode {
103    /// Create a new node with the given name, identity transform, no mesh, no children.
104    pub fn new(name: impl Into<String>) -> Self {
105        Self {
106            name: name.into(),
107            transform: Transform::identity(),
108            mesh: None,
109            children: Vec::new(),
110        }
111    }
112
113    /// Builder-style: set the transform.
114    pub fn with_transform(mut self, t: Transform) -> Self {
115        self.transform = t;
116        self
117    }
118
119    /// Builder-style: attach a mesh.
120    pub fn with_mesh(mut self, mesh: MeshBuffers) -> Self {
121        self.mesh = Some(mesh);
122        self
123    }
124
125    /// Add a child node.
126    pub fn add_child(&mut self, child: SceneNode) {
127        self.children.push(child);
128    }
129
130    /// Count total nodes: self plus all descendants (recursive).
131    pub fn node_count(&self) -> usize {
132        1 + self.children.iter().map(|c| c.node_count()).sum::<usize>()
133    }
134
135    /// Count nodes that carry a mesh: self (if mesh is Some) plus all descendants.
136    pub fn mesh_count(&self) -> usize {
137        let self_has = if self.mesh.is_some() { 1 } else { 0 };
138        self_has + self.children.iter().map(|c| c.mesh_count()).sum::<usize>()
139    }
140
141    /// Collect all node names in depth-first (pre-order) traversal.
142    pub fn all_names(&self) -> Vec<String> {
143        let mut names = vec![self.name.clone()];
144        for child in &self.children {
145            names.extend(child.all_names());
146        }
147        names
148    }
149}
150
151// ── SceneGraph ────────────────────────────────────────────────────────────────
152
153/// A scene graph with a single root node.
154pub struct SceneGraph {
155    pub root: SceneNode,
156}
157
158impl SceneGraph {
159    /// Create a new scene graph whose root has the given name.
160    pub fn new(root_name: impl Into<String>) -> Self {
161        Self {
162            root: SceneNode::new(root_name),
163        }
164    }
165
166    /// Total node count (root + all descendants).
167    pub fn node_count(&self) -> usize {
168        self.root.node_count()
169    }
170
171    /// Total mesh count across all nodes.
172    pub fn mesh_count(&self) -> usize {
173        self.root.mesh_count()
174    }
175}
176
177// ── GLB export ────────────────────────────────────────────────────────────────
178
179/// Internal: one mesh node collected during the depth-first walk.
180#[allow(dead_code)]
181struct MeshEntry {
182    /// Flat index in the `gltf_nodes` array.
183    gltf_node_idx: usize,
184    name: String,
185    transform: Transform,
186    mesh: MeshBuffers,
187}
188
189/// Internal: a GLTF node record (may or may not reference a mesh).
190struct GltfNodeRecord {
191    name: String,
192    transform: Transform,
193    /// Index into the GLTF meshes array (only for nodes with a mesh).
194    mesh_gltf_idx: Option<usize>,
195    /// Indices into the gltf_nodes array for children.
196    children: Vec<usize>,
197}
198
199/// Walk the scene graph depth-first, assigning each SceneNode a GLTF node index,
200/// filling `gltf_records` and `mesh_entries`.
201fn walk(
202    node: &SceneNode,
203    gltf_records: &mut Vec<GltfNodeRecord>,
204    mesh_entries: &mut Vec<MeshEntry>,
205) -> usize {
206    let my_idx = gltf_records.len();
207    // Reserve slot; we'll fill children after recursion.
208    gltf_records.push(GltfNodeRecord {
209        name: node.name.clone(),
210        transform: node.transform.clone(),
211        mesh_gltf_idx: None,
212        children: Vec::new(),
213    });
214
215    // Recurse into children first (DFS pre-order: index self, then children).
216    let mut child_indices = Vec::new();
217    for child in &node.children {
218        let child_idx = walk(child, gltf_records, mesh_entries);
219        child_indices.push(child_idx);
220    }
221    gltf_records[my_idx].children = child_indices;
222
223    // If this node has a mesh, we'll resolve mesh_gltf_idx after the full walk
224    // when we know how many mesh entries there are.  Store the entry index for now.
225    if let Some(mesh) = node.mesh.clone() {
226        let entry_idx = mesh_entries.len(); // this will be the GLTF mesh index
227        mesh_entries.push(MeshEntry {
228            gltf_node_idx: my_idx,
229            name: node.name.clone(),
230            transform: node.transform.clone(),
231            mesh,
232        });
233        gltf_records[my_idx].mesh_gltf_idx = Some(entry_idx);
234    }
235
236    my_idx
237}
238
/// Where one mesh's data sits inside the shared BIN payload.
///
/// All offsets and lengths are in bytes, measured from the start of the payload.
struct MeshBinLayout {
    /// Start of the position data.
    pos_offset: usize,
    /// Start of the normal data.
    norm_offset: usize,
    /// Start of the texture-coordinate data.
    uv_offset: usize,
    /// Start of the index data.
    idx_offset: usize,
    /// Number of entries in the mesh's position list.
    n_verts: usize,
    /// Number of entries in the mesh's index list.
    n_idx: usize,
    /// Size of the position data.
    pos_bytes_len: usize,
    /// Size of the normal data.
    norm_bytes_len: usize,
    /// Size of the texture-coordinate data.
    uv_bytes_len: usize,
    /// Size of the index data.
    idx_bytes_len: usize,
}
252
253/// Export a scene graph to a GLB 2.0 file.
254///
255/// Each mesh node becomes a GLTF node with its transform applied.
256/// Child nodes become child nodes in the GLTF node hierarchy.
257/// Only nodes with meshes produce GLTF mesh entries.
258pub fn export_scene_graph_glb(graph: &SceneGraph, path: &Path) -> Result<()> {
259    // ── 1. Walk the scene graph ───────────────────────────────────────────────
260    let mut gltf_records: Vec<GltfNodeRecord> = Vec::new();
261    let mut mesh_entries: Vec<MeshEntry> = Vec::new();
262
263    let root_idx = walk(&graph.root, &mut gltf_records, &mut mesh_entries);
264
265    // ── 2. Build BIN chunk from mesh entries ─────────────────────────────────
266    let mut bin_data: Vec<u8> = Vec::new();
267    let mut bin_layouts: Vec<MeshBinLayout> = Vec::new();
268
269    for entry in &mesh_entries {
270        let mesh = &entry.mesh;
271        let pos_bytes: &[u8] = cast_slice(&mesh.positions);
272        let norm_bytes: &[u8] = cast_slice(&mesh.normals);
273        let uv_bytes: &[u8] = cast_slice(&mesh.uvs);
274        let idx_bytes: &[u8] = cast_slice(&mesh.indices);
275
276        let pos_offset = bin_data.len();
277        bin_data.extend_from_slice(pos_bytes);
278        let norm_offset = bin_data.len();
279        bin_data.extend_from_slice(norm_bytes);
280        let uv_offset = bin_data.len();
281        bin_data.extend_from_slice(uv_bytes);
282        let idx_offset = bin_data.len();
283        bin_data.extend_from_slice(idx_bytes);
284
285        bin_layouts.push(MeshBinLayout {
286            pos_offset,
287            norm_offset,
288            uv_offset,
289            idx_offset,
290            n_verts: mesh.positions.len(),
291            n_idx: mesh.indices.len(),
292            pos_bytes_len: pos_bytes.len(),
293            norm_bytes_len: norm_bytes.len(),
294            uv_bytes_len: uv_bytes.len(),
295            idx_bytes_len: idx_bytes.len(),
296        });
297    }
298
299    // Pad BIN to 4-byte boundary
300    while !bin_data.len().is_multiple_of(4) {
301        bin_data.push(0x00);
302    }
303
304    // ── 3. Build accessors, bufferViews, meshes JSON ─────────────────────────
305    let mut accessors: Vec<serde_json::Value> = Vec::new();
306    let mut buffer_views: Vec<serde_json::Value> = Vec::new();
307    let mut meshes_json: Vec<serde_json::Value> = Vec::new();
308
309    for (mesh_idx, (entry, layout)) in mesh_entries.iter().zip(bin_layouts.iter()).enumerate() {
310        let pos_bv_idx = buffer_views.len();
311        buffer_views.push(json!({
312            "buffer": 0,
313            "byteOffset": layout.pos_offset,
314            "byteLength": layout.pos_bytes_len
315        }));
316
317        let norm_bv_idx = buffer_views.len();
318        buffer_views.push(json!({
319            "buffer": 0,
320            "byteOffset": layout.norm_offset,
321            "byteLength": layout.norm_bytes_len
322        }));
323
324        let uv_bv_idx = buffer_views.len();
325        buffer_views.push(json!({
326            "buffer": 0,
327            "byteOffset": layout.uv_offset,
328            "byteLength": layout.uv_bytes_len
329        }));
330
331        let idx_bv_idx = buffer_views.len();
332        buffer_views.push(json!({
333            "buffer": 0,
334            "byteOffset": layout.idx_offset,
335            "byteLength": layout.idx_bytes_len
336        }));
337
338        let pos_acc_idx = accessors.len();
339        accessors.push(json!({
340            "bufferView": pos_bv_idx,
341            "componentType": 5126,
342            "count": layout.n_verts,
343            "type": "VEC3"
344        }));
345
346        let norm_acc_idx = accessors.len();
347        accessors.push(json!({
348            "bufferView": norm_bv_idx,
349            "componentType": 5126,
350            "count": layout.n_verts,
351            "type": "VEC3"
352        }));
353
354        let uv_acc_idx = accessors.len();
355        accessors.push(json!({
356            "bufferView": uv_bv_idx,
357            "componentType": 5126,
358            "count": layout.n_verts,
359            "type": "VEC2"
360        }));
361
362        let idx_acc_idx = accessors.len();
363        accessors.push(json!({
364            "bufferView": idx_bv_idx,
365            "componentType": 5125,
366            "count": layout.n_idx,
367            "type": "SCALAR"
368        }));
369
370        let _ = mesh_idx; // used implicitly via entry reference below
371
372        meshes_json.push(json!({
373            "name": entry.name,
374            "primitives": [{
375                "attributes": {
376                    "POSITION":   pos_acc_idx,
377                    "NORMAL":     norm_acc_idx,
378                    "TEXCOORD_0": uv_acc_idx
379                },
380                "indices": idx_acc_idx
381            }]
382        }));
383    }
384
385    // ── 4. Build GLTF nodes JSON ──────────────────────────────────────────────
386    let mut nodes_json: Vec<serde_json::Value> = Vec::new();
387
388    for record in &gltf_records {
389        let m = &record.transform.matrix;
390        // GLTF matrix is column-major, same as our storage: 16 floats.
391        let matrix_val: Vec<f64> = m.iter().map(|&v| v as f64).collect();
392
393        let node_val = if let Some(mesh_idx) = record.mesh_gltf_idx {
394            if record.children.is_empty() {
395                json!({
396                    "name": record.name,
397                    "matrix": matrix_val,
398                    "mesh": mesh_idx
399                })
400            } else {
401                json!({
402                    "name": record.name,
403                    "matrix": matrix_val,
404                    "mesh": mesh_idx,
405                    "children": record.children
406                })
407            }
408        } else if record.children.is_empty() {
409            json!({
410                "name": record.name,
411                "matrix": matrix_val
412            })
413        } else {
414            json!({
415                "name": record.name,
416                "matrix": matrix_val,
417                "children": record.children
418            })
419        };
420
421        nodes_json.push(node_val);
422    }
423
424    // ── 5. Build top-level GLTF JSON ─────────────────────────────────────────
425    let total_bin = bin_data.len() as u32;
426
427    let gltf = json!({
428        "asset": { "version": "2.0", "generator": "oxihuman-export/scene_graph" },
429        "scene": 0,
430        "scenes": [{ "name": graph.root.name, "nodes": [root_idx] }],
431        "nodes": nodes_json,
432        "meshes": meshes_json,
433        "accessors": accessors,
434        "bufferViews": buffer_views,
435        "buffers": [{ "byteLength": total_bin }]
436    });
437
438    let mut json_bytes = serde_json::to_vec(&gltf)?;
439    // Pad JSON to 4-byte boundary with spaces
440    while !json_bytes.len().is_multiple_of(4) {
441        json_bytes.push(b' ');
442    }
443
444    // ── 6. Write GLB ─────────────────────────────────────────────────────────
445    let json_chunk_len = json_bytes.len() as u32;
446    let bin_chunk_len = bin_data.len() as u32;
447    let total_len = 12 + 8 + json_chunk_len + 8 + bin_chunk_len;
448
449    let mut file = std::fs::File::create(path)?;
450
451    // GLB header (12 bytes)
452    file.write_all(&GLB_MAGIC.to_le_bytes())?;
453    file.write_all(&GLB_VERSION.to_le_bytes())?;
454    file.write_all(&total_len.to_le_bytes())?;
455
456    // JSON chunk
457    file.write_all(&json_chunk_len.to_le_bytes())?;
458    file.write_all(&CHUNK_JSON.to_le_bytes())?;
459    file.write_all(&json_bytes)?;
460
461    // BIN chunk
462    file.write_all(&bin_chunk_len.to_le_bytes())?;
463    file.write_all(&CHUNK_BIN.to_le_bytes())?;
464    file.write_all(&bin_data)?;
465
466    Ok(())
467}
468
469// ── Tests ─────────────────────────────────────────────────────────────────────
470
#[cfg(test)]
mod tests {
    use super::*;
    use oxihuman_morph::engine::MeshBuffers as MB;

    /// Build a minimal triangle mesh (3 verts, 1 tri). `has_suit` is set to true
    /// so the export functions do not reject it.
    fn tri_mesh(y_offset: f32) -> MeshBuffers {
        MeshBuffers::from_morph(MB {
            positions: vec![
                [0.0, y_offset, 0.0],
                [1.0, y_offset, 0.0],
                [0.0, y_offset + 1.0, 0.0],
            ],
            normals: vec![[0.0, 0.0, 1.0]; 3],
            uvs: vec![[0.0, 0.0]; 3],
            indices: vec![0, 1, 2],
            has_suit: true,
        })
    }

    /// Location for a scratch GLB file inside the platform's temp directory,
    /// so the tests do not depend on a Unix-style `/tmp` existing.
    fn temp_glb(file_name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(file_name)
    }

    // ── Transform tests ───────────────────────────────────────────────────────

    #[test]
    fn transform_identity_is_identity() {
        let t = Transform::identity();
        assert!(t.is_identity(), "identity matrix must report is_identity()");
    }

    #[test]
    fn transform_translation_correct_matrix() {
        let t = Transform::translation(3.0, 5.0, 7.0);
        // In column-major layout the translation is at indices 12, 13, 14.
        assert_eq!(t.matrix[12], 3.0);
        assert_eq!(t.matrix[13], 5.0);
        assert_eq!(t.matrix[14], 7.0);
        // Diagonal should be [1,1,1,1]
        assert_eq!(t.matrix[0], 1.0);
        assert_eq!(t.matrix[5], 1.0);
        assert_eq!(t.matrix[10], 1.0);
        assert_eq!(t.matrix[15], 1.0);
    }

    #[test]
    fn transform_compose_identity_unchanged() {
        let t = Transform::translation(1.0, 2.0, 3.0);
        let id = Transform::identity();
        let composed = t.compose(&id);
        // composing with identity should give the same matrix
        for (a, b) in composed.matrix.iter().zip(t.matrix.iter()) {
            assert!(
                (a - b).abs() < 1e-6,
                "compose with identity changed the matrix"
            );
        }
    }

    #[test]
    fn transform_scale_compose_translation() {
        // scale(2,2,2) * translation(1,0,0) should scale the translation too.
        let s = Transform::scale(2.0, 2.0, 2.0);
        let tr = Transform::translation(1.0, 0.0, 0.0);
        let composed = s.compose(&tr);
        // column 3 (indices 12-14) should be [2, 0, 0] because scale * [1,0,0,1]
        assert!((composed.matrix[12] - 2.0).abs() < 1e-6);
        assert!((composed.matrix[13]).abs() < 1e-6);
        assert!((composed.matrix[14]).abs() < 1e-6);
    }

    // ── SceneNode tests ───────────────────────────────────────────────────────

    #[test]
    fn scene_node_no_children_count_one() {
        let node = SceneNode::new("root");
        assert_eq!(node.node_count(), 1);
    }

    #[test]
    fn scene_node_with_children_count_correct() {
        let mut root = SceneNode::new("root");
        root.add_child(SceneNode::new("child_a"));
        let mut child_b = SceneNode::new("child_b");
        child_b.add_child(SceneNode::new("grandchild"));
        root.add_child(child_b);
        // root + child_a + child_b + grandchild = 4
        assert_eq!(root.node_count(), 4);
    }

    #[test]
    fn scene_graph_mesh_count_correct() {
        let mut graph = SceneGraph::new("scene");
        graph
            .root
            .add_child(SceneNode::new("body").with_mesh(tri_mesh(0.0)));
        graph
            .root
            .add_child(SceneNode::new("clothing").with_mesh(tri_mesh(1.0)));
        graph.root.add_child(SceneNode::new("empty_node"));
        assert_eq!(graph.mesh_count(), 2);
    }

    #[test]
    fn all_names_depth_first_order() {
        let mut root = SceneNode::new("root");
        let mut child_a = SceneNode::new("child_a");
        child_a.add_child(SceneNode::new("grandchild_a1"));
        child_a.add_child(SceneNode::new("grandchild_a2"));
        root.add_child(child_a);
        root.add_child(SceneNode::new("child_b"));

        let names = root.all_names();
        assert_eq!(
            names,
            vec![
                "root",
                "child_a",
                "grandchild_a1",
                "grandchild_a2",
                "child_b"
            ]
        );
    }

    // ── Export tests ──────────────────────────────────────────────────────────

    #[test]
    fn export_scene_graph_creates_file() {
        let path = temp_glb("test_scene_graph_creates.glb");
        let graph = SceneGraph::new("test");
        export_scene_graph_glb(&graph, &path).expect("export must succeed");
        assert!(path.exists(), "GLB file must be created");
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn export_scene_graph_valid_glb_header() {
        let path = temp_glb("test_scene_graph_header.glb");
        let mut graph = SceneGraph::new("header_test");
        graph
            .root
            .add_child(SceneNode::new("body").with_mesh(tri_mesh(0.0)));
        export_scene_graph_glb(&graph, &path).expect("export must succeed");

        let bytes = std::fs::read(&path).expect("should succeed");
        assert!(bytes.len() >= 12, "GLB must have at least 12 bytes");
        // Magic "glTF" in LE = [0x67, 0x6C, 0x54, 0x46]
        assert_eq!(
            &bytes[0..4],
            &[0x67u8, 0x6Cu8, 0x54u8, 0x46u8],
            "GLB magic must be glTF"
        );
        // Version = 2 (little-endian u32)
        let version = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
        assert_eq!(version, 2, "GLB version must be 2");
        // Declared total length (little-endian u32) must match the file size.
        let declared_len = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]);
        assert_eq!(
            declared_len as usize,
            bytes.len(),
            "GLB header length must equal the file size"
        );
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn export_empty_mesh_nodes_still_creates_file() {
        let path = temp_glb("test_scene_graph_empty_mesh.glb");
        // Graph with nodes but no meshes
        let mut graph = SceneGraph::new("empty_mesh_test");
        graph.root.add_child(SceneNode::new("no_mesh_child"));
        export_scene_graph_glb(&graph, &path).expect("export must succeed even without meshes");
        assert!(path.exists(), "GLB file must be created");
        let bytes = std::fs::read(&path).expect("should succeed");
        assert!(bytes.len() >= 12);
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn export_two_mesh_nodes() {
        let path = temp_glb("test_scene_graph_two_meshes.glb");
        let mut graph = SceneGraph::new("two_mesh_scene");
        graph.root.add_child(
            SceneNode::new("body")
                .with_mesh(tri_mesh(0.0))
                .with_transform(Transform::translation(0.0, 0.0, 0.0)),
        );
        graph.root.add_child(
            SceneNode::new("hat")
                .with_mesh(tri_mesh(2.0))
                .with_transform(Transform::translation(0.0, 1.8, 0.0)),
        );
        export_scene_graph_glb(&graph, &path).expect("export must succeed");
        assert!(path.exists());
        let bytes = std::fs::read(&path).expect("should succeed");
        // The file must contain more than the bare 12-byte GLB header.
        assert!(bytes.len() > 12);
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn export_nested_hierarchy() {
        let path = temp_glb("test_scene_graph_nested.glb");
        let mut graph = SceneGraph::new("nested");
        let mut torso = SceneNode::new("torso").with_mesh(tri_mesh(0.0));
        let head = SceneNode::new("head")
            .with_mesh(tri_mesh(1.5))
            .with_transform(Transform::translation(0.0, 1.5, 0.0));
        torso.add_child(head);
        graph.root.add_child(torso);
        assert_eq!(graph.mesh_count(), 2);
        export_scene_graph_glb(&graph, &path).expect("nested export must succeed");
        assert!(path.exists());
        std::fs::remove_file(&path).ok();
    }
}