Skip to main content

burn_synth/
io.rs

1use std::path::{Path, PathBuf};
2
/// An image input, given either as a filesystem path or as raw bytes held in memory.
#[derive(Clone, Debug)]
pub enum ImageSource {
    /// Image located at this path.
    Path(PathBuf),
    /// Image bytes already loaded into memory.
    Bytes(Vec<u8>),
}

impl ImageSource {
    /// Wraps anything convertible into a [`PathBuf`] as [`ImageSource::Path`].
    pub fn from_path(path: impl Into<PathBuf>) -> Self {
        let owned: PathBuf = path.into();
        ImageSource::Path(owned)
    }

    /// Wraps anything convertible into a byte vector as [`ImageSource::Bytes`].
    pub fn from_bytes(bytes: impl Into<Vec<u8>>) -> Self {
        let owned: Vec<u8> = bytes.into();
        ImageSource::Bytes(owned)
    }

    /// Returns the path for [`ImageSource::Path`]; `None` for the in-memory variant.
    pub fn path(&self) -> Option<&Path> {
        if let ImageSource::Path(buf) = self {
            Some(buf.as_path())
        } else {
            None
        }
    }

    /// Returns the bytes for [`ImageSource::Bytes`]; `None` for the path variant.
    pub fn bytes(&self) -> Option<&[u8]> {
        if let ImageSource::Bytes(data) = self {
            Some(data.as_slice())
        } else {
            None
        }
    }
}
32
/// A text prompt, stored as an owned `String`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TextPrompt(pub String);

impl From<String> for TextPrompt {
    /// Takes ownership of `text` without copying.
    fn from(text: String) -> Self {
        TextPrompt(text)
    }
}

impl From<&str> for TextPrompt {
    /// Copies `text` into a new owned prompt.
    fn from(text: &str) -> Self {
        TextPrompt::from(text.to_owned())
    }
}

impl TextPrompt {
    /// Borrows the prompt text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
53
54#[cfg(any(feature = "runtime", feature = "wasm-api"))]
55use std::borrow::Cow;
56#[cfg(any(feature = "runtime", feature = "wasm-api"))]
57use std::fs;
58#[cfg(any(feature = "runtime", feature = "wasm-api"))]
59use std::io;
60
61#[cfg(any(feature = "runtime", feature = "wasm-api"))]
62use crate::mesh::{Mesh, MeshTexture};
63#[cfg(any(feature = "runtime", feature = "wasm-api"))]
64use image::ImageEncoder;
65#[cfg(any(feature = "runtime", feature = "wasm-api"))]
66use serde_json::{Value, json};
67
/// Packed binary payload for a GLB export plus the byte ranges of each section within it.
///
/// Every `(usize, usize)` image view is a `(byte_offset, byte_length)` pair into `buffer`.
#[cfg(any(feature = "runtime", feature = "wasm-api"))]
#[derive(Clone, Debug)]
struct MeshBinaryLayout {
    /// All sections concatenated; becomes the GLB binary chunk.
    buffer: Vec<u8>,
    /// Start of the vertex-position section.
    positions_byte_offset: usize,
    /// Size in bytes of the vertex-position section.
    positions_byte_length: usize,
    /// Start of the triangle-index section.
    indices_byte_offset: usize,
    /// Size in bytes of the triangle-index section.
    indices_byte_length: usize,
    /// Start of the texture-coordinate section; `None` when no UVs were written.
    uvs_byte_offset: Option<usize>,
    /// Size in bytes of the texture-coordinate section; `None` when no UVs were written.
    uvs_byte_length: Option<usize>,
    /// Embedded base-colour PNG, if any.
    base_color_image_view: Option<(usize, usize)>,
    /// Embedded metallic-roughness PNG, if any.
    metallic_roughness_image_view: Option<(usize, usize)>,
    /// Embedded normal-map PNG, if any.
    normal_image_view: Option<(usize, usize)>,
    /// Embedded emissive PNG, if any.
    emissive_image_view: Option<(usize, usize)>,
    /// Embedded occlusion PNG, if any.
    occlusion_image_view: Option<(usize, usize)>,
    /// Per-axis minimum of the vertex positions.
    min: [f32; 3],
    /// Per-axis maximum of the vertex positions.
    max: [f32; 3],
}
86
/// Serializes `mesh` as a binary glTF (`.glb`) file at `path`.
///
/// The GLB bytes are built in memory first. An unexportable mesh therefore fails before any
/// directory is created or any file is touched. Missing parent directories are then created,
/// and the file is written in one call.
///
/// # Errors
/// Returns a human-readable message in three cases:
/// - the mesh cannot be encoded;
/// - the parent directory cannot be created;
/// - the file cannot be written.
#[cfg(any(feature = "runtime", feature = "wasm-api"))]
pub fn write_glb_mesh(path: &Path, mesh: &Mesh) -> Result<(), String> {
    // Encode first so a failed export leaves no stray directories behind.
    let bytes = mesh_to_glb_bytes(mesh)?;
    ensure_parent_dir(path).map_err(|err| {
        format!(
            "failed to create parent directory for {}: {err}",
            path.display()
        )
    })?;
    fs::write(path, bytes).map_err(|err| format!("failed to write {}: {err}", path.display()))
}
93
/// Encodes `mesh` as an in-memory binary glTF (GLB) container.
///
/// The JSON chunk is the document produced by `gltf_json`. The binary chunk is the packed
/// buffer produced by `build_mesh_binary_layout`.
///
/// # Errors
/// Propagates layout errors, such as an empty mesh or a texture that fails to encode.
/// JSON serialization and GLB assembly failures are reported as strings.
#[cfg(any(feature = "runtime", feature = "wasm-api"))]
pub fn mesh_to_glb_bytes(mesh: &Mesh) -> Result<Vec<u8>, String> {
    let layout = build_mesh_binary_layout(mesh)?;
    let json_chunk = serde_json::to_vec(&gltf_json(mesh, &layout))
        .map_err(|err| format!("failed to serialize glTF json chunk: {err}"))?;

    // `length` is only a placeholder here.
    // NOTE(review): the gltf crate's GLB writer appears to compute the real total length
    // itself when serializing. Confirm this against the pinned crate version.
    let header = gltf::binary::Header {
        magic: *b"glTF",
        version: 2,
        length: 0,
    };
    let container = gltf::Glb {
        header,
        json: Cow::Owned(json_chunk),
        bin: Some(Cow::Owned(layout.buffer)),
    };
    container
        .to_vec()
        .map_err(|err| format!("failed to build GLB: {err}"))
}
112
/// Packs `mesh` into one little-endian byte buffer for the GLB binary chunk and records where
/// each section landed.
///
/// Sections are written in this order:
/// 1. vertex positions (three `f32` per vertex);
/// 2. texture coordinates (two values per vertex), written only when there is exactly one UV
///    per vertex;
/// 3. triangle indices;
/// 4. the optional PBR textures, encoded as PNG.
///
/// Every section after the positions starts on a 4-byte boundary. The per-axis position
/// bounds that the glTF `POSITION` accessor needs are computed along the way.
///
/// # Errors
/// Fails in any of these cases:
/// - the mesh has no vertices or no faces;
/// - a face references a vertex index that does not exist;
/// - a texture cannot be PNG-encoded.
#[cfg(any(feature = "runtime", feature = "wasm-api"))]
fn build_mesh_binary_layout(mesh: &Mesh) -> Result<MeshBinaryLayout, String> {
    if mesh.vertices.is_empty() {
        return Err("cannot export empty mesh".to_string());
    }
    // The exporter always emits an indexed triangle primitive. glTF requires accessor
    // `count` >= 1 and bufferView `byteLength` >= 1, so a face-less mesh would be invalid.
    if mesh.faces.is_empty() {
        return Err("cannot export mesh without faces".to_string());
    }
    let vertex_count = mesh.vertices.len();
    // `vertex_count > 0` here, so equal lengths also imply the UV list is non-empty.
    let has_uvs = mesh.uvs.len() == vertex_count;

    let mut min = [f32::INFINITY; 3];
    let mut max = [f32::NEG_INFINITY; 3];
    for vertex in &mesh.vertices {
        for axis in 0..3 {
            min[axis] = min[axis].min(vertex[axis]);
            max[axis] = max[axis].max(vertex[axis]);
        }
    }

    // Rough size estimate: 12 B/vertex, 8 B/UV, 12 B/face, plus headroom for texture PNGs.
    let mut buffer = Vec::with_capacity(
        vertex_count * 12 + mesh.uvs.len() * 8 + mesh.faces.len() * 12 + 8192,
    );

    let positions_byte_offset = buffer.len();
    for vertex in &mesh.vertices {
        for component in vertex {
            buffer.extend_from_slice(&component.to_le_bytes());
        }
    }
    let positions_byte_length = buffer.len() - positions_byte_offset;

    let (uvs_byte_offset, uvs_byte_length) = if has_uvs {
        pad_buffer_4(&mut buffer);
        let offset = buffer.len();
        for uv in &mesh.uvs {
            buffer.extend_from_slice(&uv[0].to_le_bytes());
            buffer.extend_from_slice(&uv[1].to_le_bytes());
        }
        (Some(offset), Some(buffer.len() - offset))
    } else {
        (None, None)
    };

    pad_buffer_4(&mut buffer);
    let indices_byte_offset = buffer.len();
    for (face_index, face) in mesh.faces.iter().enumerate() {
        for index in face {
            // glTF requires every index to be smaller than the POSITION accessor count.
            if *index as usize >= vertex_count {
                return Err(format!(
                    "face {face_index} references vertex {index}, but the mesh has only {vertex_count} vertices"
                ));
            }
            buffer.extend_from_slice(&index.to_le_bytes());
        }
    }
    let indices_byte_length = buffer.len() - indices_byte_offset;

    let mut base_color_image_view = None;
    let mut metallic_roughness_image_view = None;
    let mut normal_image_view = None;
    let mut emissive_image_view = None;
    let mut occlusion_image_view = None;
    if let Some(pbr) = mesh.pbr_textures.as_ref() {
        base_color_image_view = Some(append_texture_png(&mut buffer, &pbr.base_color)?);
        metallic_roughness_image_view =
            Some(append_texture_png(&mut buffer, &pbr.metallic_roughness)?);
        normal_image_view = pbr
            .normal
            .as_ref()
            .map(|texture| append_texture_png(&mut buffer, texture))
            .transpose()?;
        emissive_image_view = pbr
            .emissive
            .as_ref()
            .map(|texture| append_texture_png(&mut buffer, texture))
            .transpose()?;
        occlusion_image_view = pbr
            .occlusion
            .as_ref()
            .map(|texture| append_texture_png(&mut buffer, texture))
            .transpose()?;
    }

    Ok(MeshBinaryLayout {
        buffer,
        positions_byte_offset,
        positions_byte_length,
        indices_byte_offset,
        indices_byte_length,
        uvs_byte_offset,
        uvs_byte_length,
        base_color_image_view,
        metallic_roughness_image_view,
        normal_image_view,
        emissive_image_view,
        occlusion_image_view,
        min,
        max,
    })
}

/// PNG-encodes `texture` and appends it to `buffer` at the next 4-byte boundary.
///
/// Returns the `(byte_offset, byte_length)` of the appended PNG within `buffer`.
#[cfg(any(feature = "runtime", feature = "wasm-api"))]
fn append_texture_png(
    buffer: &mut Vec<u8>,
    texture: &MeshTexture,
) -> Result<(usize, usize), String> {
    let png = encode_rgba_texture_png(texture)?;
    pad_buffer_4(buffer);
    let offset = buffer.len();
    buffer.extend_from_slice(png.as_slice());
    Ok((offset, png.len()))
}
216
/// Builds the glTF 2.0 JSON document that describes `mesh` on top of the binary `layout`.
///
/// The document has one scene. That scene holds one node, which references one mesh with a
/// single triangle primitive (`mode` 4). Buffer views and accessors use fixed indices:
/// - 0: vertex positions (`POSITION`, `VEC3` floats with `min`/`max` bounds);
/// - 1: triangle indices (`SCALAR` unsigned ints);
/// - 2: texture coordinates (`TEXCOORD_0`, `VEC2` floats), only when `layout` contains UVs.
///
/// Each embedded PNG texture is appended as a further buffer view and wrapped in an image
/// and a texture. A material is emitted whenever the mesh has a material or PBR textures.
#[cfg(any(feature = "runtime", feature = "wasm-api"))]
fn gltf_json(mesh: &Mesh, layout: &MeshBinaryLayout) -> Value {
    // The layout is the single source of truth for whether UVs were written. Re-checking
    // `mesh.uvs` here could drift and emit a TEXCOORD_0 pointing at a missing buffer view.
    let uv_view = layout.uvs_byte_offset.zip(layout.uvs_byte_length);

    let mut primitive = json!({
        "attributes": {
            "POSITION": 0
        },
        "indices": 1,
        "mode": 4
    });
    if uv_view.is_some() {
        primitive["attributes"]["TEXCOORD_0"] = json!(2);
    }

    let buffers = vec![json!({
        "byteLength": layout.buffer.len(),
    })];

    // Buffer-view targets: 34962 = ARRAY_BUFFER (vertex data), 34963 = ELEMENT_ARRAY_BUFFER.
    let mut buffer_views = vec![
        json!({
            "buffer": 0,
            "byteOffset": layout.positions_byte_offset,
            "byteLength": layout.positions_byte_length,
            "target": 34962
        }),
        json!({
            "buffer": 0,
            "byteOffset": layout.indices_byte_offset,
            "byteLength": layout.indices_byte_length,
            "target": 34963
        }),
    ];
    if let Some((uv_offset, uv_len)) = uv_view {
        buffer_views.push(json!({
            "buffer": 0,
            "byteOffset": uv_offset,
            "byteLength": uv_len,
            "target": 34962
        }));
    }

    // componentType: 5126 = FLOAT, 5125 = UNSIGNED_INT.
    let mut accessors = vec![
        json!({
            "bufferView": 0,
            "componentType": 5126,
            "count": mesh.vertices.len(),
            "type": "VEC3",
            "min": layout.min,
            "max": layout.max
        }),
        json!({
            "bufferView": 1,
            "componentType": 5125,
            "count": mesh.faces.len() * 3,
            "type": "SCALAR"
        }),
    ];
    if uv_view.is_some() {
        accessors.push(json!({
            "bufferView": 2,
            "componentType": 5126,
            "count": mesh.uvs.len(),
            "type": "VEC2"
        }));
    }

    let mut images = Vec::new();
    let mut textures = Vec::new();
    let mut materials = Vec::new();
    // Registers one embedded PNG as bufferView -> image -> texture; returns the texture index.
    let mut push_texture_image = |byte_offset: usize, byte_length: usize| -> usize {
        let view_index = buffer_views.len();
        buffer_views.push(json!({
            "buffer": 0,
            "byteOffset": byte_offset,
            "byteLength": byte_length
        }));
        let image_index = images.len();
        images.push(json!({
            "bufferView": view_index,
            "mimeType": "image/png"
        }));
        let texture_index = textures.len();
        textures.push(json!({ "source": image_index }));
        texture_index
    };

    // Scalar factors are clamped to [0, 1], the range glTF allows for them.
    let mut pbr_mr = mesh.material.map_or_else(
        || json!({}),
        |material| {
            json!({
                "baseColorFactor": [
                    material.base_color[0],
                    material.base_color[1],
                    material.base_color[2],
                    material.alpha.clamp(0.0, 1.0)
                ],
                "metallicFactor": material.metallic.clamp(0.0, 1.0),
                "roughnessFactor": material.roughness.clamp(0.0, 1.0)
            })
        },
    );
    // NOTE(review): texture bindings are emitted even when the primitive has no TEXCOORD_0
    // (textures present but no per-vertex UVs). glTF validators reject that combination.
    if let Some((base_offset, base_len)) = layout.base_color_image_view {
        let texture_index = push_texture_image(base_offset, base_len);
        pbr_mr["baseColorTexture"] = json!({ "index": texture_index });
    }
    if let Some((mr_offset, mr_len)) = layout.metallic_roughness_image_view {
        let texture_index = push_texture_image(mr_offset, mr_len);
        pbr_mr["metallicRoughnessTexture"] = json!({ "index": texture_index });
    }

    if mesh.material.is_some() || mesh.pbr_textures.is_some() {
        let alpha = mesh
            .material
            .map_or(1.0, |material| material.alpha)
            .clamp(0.0, 1.0);
        let material_index = materials.len();
        // Alpha just below 1.0 is treated as opaque so near-1 values do not force blending.
        let mut material = json!({
            "pbrMetallicRoughness": pbr_mr,
            "alphaMode": if alpha < 0.995 { "BLEND" } else { "OPAQUE" },
            "doubleSided": true
        });
        if let Some((normal_offset, normal_len)) = layout.normal_image_view {
            let texture_index = push_texture_image(normal_offset, normal_len);
            material["normalTexture"] = json!({ "index": texture_index });
        }
        if let Some((emissive_offset, emissive_len)) = layout.emissive_image_view {
            let texture_index = push_texture_image(emissive_offset, emissive_len);
            material["emissiveTexture"] = json!({ "index": texture_index });
            // glTF's default emissiveFactor is black, which would hide the emissive texture.
            material["emissiveFactor"] = json!([1.0, 1.0, 1.0]);
        }
        if let Some((occlusion_offset, occlusion_len)) = layout.occlusion_image_view {
            let texture_index = push_texture_image(occlusion_offset, occlusion_len);
            material["occlusionTexture"] = json!({ "index": texture_index });
        }
        materials.push(material);
        primitive["material"] = json!(material_index);
    }

    let mut gltf = json!({
        "asset": {
            "version": "2.0",
            "generator": "burn_synth"
        },
        "scene": 0,
        "scenes": [
            { "nodes": [0] }
        ],
        "nodes": [
            { "mesh": 0 }
        ],
        "meshes": [
            {
                "primitives": [
                    primitive
                ]
            }
        ],
        "buffers": buffers,
        "bufferViews": buffer_views,
        "accessors": accessors
    });
    // Optional top-level arrays are omitted entirely when empty.
    if !materials.is_empty() {
        gltf["materials"] = Value::Array(materials);
    }
    if !images.is_empty() {
        gltf["images"] = Value::Array(images);
    }
    if !textures.is_empty() {
        gltf["textures"] = Value::Array(textures);
    }
    gltf
}
385
/// Appends zero bytes until `buffer.len()` is a multiple of 4.
///
/// glTF requires accessor data to be aligned to its component size. Starting every section
/// on a 4-byte boundary satisfies that for the 4-byte float and unsigned-int data written
/// here.
#[cfg(any(feature = "runtime", feature = "wasm-api"))]
fn pad_buffer_4(buffer: &mut Vec<u8>) {
    let remainder = buffer.len() % 4;
    if remainder != 0 {
        let padded_len = buffer.len() + (4 - remainder);
        buffer.resize(padded_len, 0);
    }
}
392
/// PNG-encodes a tightly packed RGBA8 texture (4 bytes per pixel, `width * height` pixels).
///
/// # Errors
/// Fails in any of these cases:
/// - `width * height * 4` overflows `usize`, which is possible on 32-bit targets such as
///   wasm32;
/// - `rgba8` does not hold exactly that many bytes;
/// - the PNG encoder reports an error.
#[cfg(any(feature = "runtime", feature = "wasm-api"))]
fn encode_rgba_texture_png(texture: &MeshTexture) -> Result<Vec<u8>, String> {
    // Checked arithmetic: with a 32-bit `usize` the product can wrap. A wrapped value could
    // let a wrongly sized pixel buffer pass the length check below and reach the encoder.
    let expected = (texture.width as usize)
        .checked_mul(texture.height as usize)
        .and_then(|pixels| pixels.checked_mul(4))
        .ok_or_else(|| {
            format!(
                "texture dimensions {}x{} overflow the addressable byte length",
                texture.width, texture.height
            )
        })?;
    if texture.rgba8.len() != expected {
        return Err(format!(
            "texture byte length mismatch: expected {}, got {}",
            expected,
            texture.rgba8.len()
        ));
    }
    let mut out = Vec::new();
    let encoder = image::codecs::png::PngEncoder::new(&mut out);
    encoder
        .write_image(
            texture.rgba8.as_slice(),
            texture.width,
            texture.height,
            image::ColorType::Rgba8.into(),
        )
        .map_err(|err| format!("failed to encode texture png: {err}"))?;
    Ok(out)
}
415
/// Creates every missing ancestor directory of `path`.
///
/// This is a no-op when `path` has no parent, or when the parent is empty (a bare file
/// name), because there is nothing to create in either case.
#[cfg(any(feature = "runtime", feature = "wasm-api"))]
fn ensure_parent_dir(path: &Path) -> io::Result<()> {
    match path.parent() {
        Some(parent) if !parent.as_os_str().is_empty() => fs::create_dir_all(parent),
        _ => Ok(()),
    }
}
425
#[cfg(all(test, feature = "runtime"))]
mod tests {
    use super::*;
    use crate::mesh::{Mesh, MeshMaterial, MeshPbrTextures, MeshTexture};

    /// Builds a `width` x `height` RGBA8 texture where every pixel is `rgba`.
    fn test_texture(width: u32, height: u32, rgba: [u8; 4]) -> MeshTexture {
        MeshTexture {
            width,
            height,
            rgba8: rgba.repeat(width as usize * height as usize),
        }
    }

    /// One UV-mapped triangle with a material plus base-colour and metallic-roughness maps.
    fn sample_mesh_with_pbr() -> Mesh {
        Mesh {
            vertices: vec![[-0.5, 0.0, 0.0], [0.5, 0.0, 0.0], [0.0, 0.8, 0.0]],
            faces: vec![[0, 1, 2]],
            uvs: vec![[0.0, 0.0], [1.0, 0.0], [0.5, 1.0]],
            material: Some(MeshMaterial {
                base_color: [1.0, 1.0, 1.0],
                metallic: 1.0,
                roughness: 1.0,
                alpha: 1.0,
            }),
            pbr_textures: Some(MeshPbrTextures {
                base_color: test_texture(2, 2, [200, 180, 160, 255]),
                metallic_roughness: test_texture(2, 2, [0, 128, 64, 255]),
                normal: None,
                emissive: None,
                occlusion: None,
            }),
        }
    }

    /// Exports `mesh` to GLB and parses the container back.
    ///
    /// Returns the decoded JSON chunk and the binary chunk length (0 if absent).
    fn export_and_parse(mesh: &Mesh) -> (Value, usize) {
        let bytes = mesh_to_glb_bytes(mesh).expect("glb export");
        let glb = gltf::Glb::from_slice(bytes.as_slice()).expect("parse glb");
        let json: Value = serde_json::from_slice(glb.json.as_ref()).expect("parse glb json");
        let bin_len = glb.bin.as_ref().map_or(0, |bin| bin.len());
        (json, bin_len)
    }

    /// Resolves the accessor referenced by `index` (a JSON integer) in `json`.
    fn accessor<'a>(json: &'a Value, index: &Value) -> &'a Value {
        let index = index.as_u64().expect("accessor index") as usize;
        &json["accessors"][index]
    }

    #[test]
    fn glb_embeds_pbr_textures_when_present() {
        let (json, _) = export_and_parse(&sample_mesh_with_pbr());

        let materials = json["materials"].as_array().expect("materials array");
        assert_eq!(materials.len(), 1);
        let pbr = &materials[0]["pbrMetallicRoughness"];
        assert!(pbr.get("baseColorTexture").is_some());
        assert!(pbr.get("metallicRoughnessTexture").is_some());
        assert!(
            json["textures"]
                .as_array()
                .is_some_and(|value| !value.is_empty())
        );
        assert!(
            json["images"]
                .as_array()
                .is_some_and(|value| !value.is_empty())
        );
    }

    #[test]
    fn glb_writes_material_when_only_textures_are_present() {
        let mut mesh = sample_mesh_with_pbr();
        mesh.material = None;
        let (json, _) = export_and_parse(&mesh);

        let materials = json["materials"].as_array().expect("materials array");
        assert_eq!(materials.len(), 1);
        assert_eq!(materials[0]["alphaMode"], "OPAQUE");
        let pbr = &materials[0]["pbrMetallicRoughness"];
        assert!(pbr.get("baseColorTexture").is_some());
        assert!(pbr.get("metallicRoughnessTexture").is_some());
    }

    #[test]
    fn glb_geometry_accessors_match_mesh() {
        let (json, _) = export_and_parse(&sample_mesh_with_pbr());
        let primitive = &json["meshes"][0]["primitives"][0];
        assert_eq!(primitive["mode"], 4);

        let position = accessor(&json, &primitive["attributes"]["POSITION"]);
        assert_eq!(position["type"], "VEC3");
        assert_eq!(position["count"], 3);
        assert_eq!(position["min"][0].as_f64(), Some(-0.5));
        assert_eq!(position["max"][0].as_f64(), Some(0.5));
        let max_y = position["max"][1].as_f64().expect("max y");
        assert!((max_y - 0.8).abs() < 1e-6, "unexpected max y {max_y}");

        let indices = accessor(&json, &primitive["indices"]);
        assert_eq!(indices["type"], "SCALAR");
        assert_eq!(indices["componentType"], 5125);
        assert_eq!(indices["count"], 3);

        let texcoord = accessor(&json, &primitive["attributes"]["TEXCOORD_0"]);
        assert_eq!(texcoord["type"], "VEC2");
        assert_eq!(texcoord["count"], 3);
    }

    #[test]
    fn glb_buffer_views_are_aligned_and_in_bounds() {
        let (json, bin_len) = export_and_parse(&sample_mesh_with_pbr());
        let buffer_len = json["buffers"][0]["byteLength"]
            .as_u64()
            .expect("buffer byteLength");
        assert!(
            buffer_len as usize <= bin_len,
            "BIN chunk shorter than declared buffer"
        );
        let views = json["bufferViews"].as_array().expect("bufferViews array");
        for view in views {
            let offset = view["byteOffset"].as_u64().expect("byteOffset");
            let length = view["byteLength"].as_u64().expect("byteLength");
            assert_eq!(offset % 4, 0, "bufferView offset {offset} is not 4-byte aligned");
            assert!(length > 0, "empty bufferView at offset {offset}");
            assert!(
                offset + length <= buffer_len,
                "bufferView at offset {offset} overruns the buffer"
            );
        }
    }

    #[test]
    fn glb_texture_references_resolve() {
        let (json, _) = export_and_parse(&sample_mesh_with_pbr());
        let views = json["bufferViews"].as_array().expect("bufferViews array");
        let images = json["images"].as_array().expect("images array");
        let textures = json["textures"].as_array().expect("textures array");

        for image in images {
            assert_eq!(image["mimeType"], "image/png");
            let view = image["bufferView"].as_u64().expect("image bufferView") as usize;
            assert!(view < views.len());
        }
        for texture in textures {
            let source = texture["source"].as_u64().expect("texture source") as usize;
            assert!(source < images.len());
        }
        let pbr = &json["materials"][0]["pbrMetallicRoughness"];
        for key in ["baseColorTexture", "metallicRoughnessTexture"] {
            let index = pbr[key]["index"].as_u64().expect(key) as usize;
            assert!(index < textures.len());
        }
    }

    #[test]
    fn glb_plain_mesh_omits_material_and_texture_sections() {
        let mut mesh = sample_mesh_with_pbr();
        mesh.uvs.clear();
        mesh.material = None;
        mesh.pbr_textures = None;
        let (json, _) = export_and_parse(&mesh);

        for key in ["materials", "images", "textures"] {
            assert!(json.get(key).is_none(), "unexpected `{key}` section");
        }
        let primitive = &json["meshes"][0]["primitives"][0];
        assert!(primitive.get("material").is_none());
        assert!(primitive["attributes"].get("TEXCOORD_0").is_none());
        assert_eq!(json["accessors"].as_array().map(Vec::len), Some(2));
    }
}