// kengaai_model_loader/lib.rs

use anyhow::{Context, Result};
use glam::{Mat3, Mat4, Vec2, Vec3, Vec4};
// use std::collections::HashMap;
use std::path::Path;

6/// 3D Vertex structure
7#[repr(C)]
8#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
9pub struct Vertex {
10    pub position: Vec3,
11    pub normal: Vec3,
12    pub tex_coords: Vec2,
13    pub tangent: Vec4,
14    pub joints: [u32; 4], // Changed to u32 for compatibility with wgpu
15    pub weights: [f32; 4],
16}
17
18impl Default for Vertex {
19    fn default() -> Self {
20        Self {
21            position: Vec3::ZERO,
22            normal: Vec3::Z,
23            tex_coords: Vec2::ZERO,
24            tangent: Vec4::new(1.0, 0.0, 0.0, 1.0),
25            joints: [0, 0, 0, 0],
26            weights: [1.0, 0.0, 0.0, 0.0],
27        }
28    }
29}
30
31/// Image data structure
32#[derive(Debug, Clone)]
33pub struct ImageData {
34    pub width: u32,
35    pub height: u32,
36    pub format: gltf::image::Format,
37    pub pixels: Vec<u8>,
38}
39
/// Metallic-roughness PBR material parameters.
#[derive(Debug, Clone)]
pub struct Material {
    /// Material name; may be empty when the source provides none.
    pub name: String,
    /// RGBA multiplier applied to the base color.
    pub base_color_factor: [f32; 4],
    /// Scalar metalness multiplier.
    pub metallic_factor: f32,
    /// Scalar roughness multiplier.
    pub roughness_factor: f32,
    /// Index identifying the base-color texture, if the material has one.
    pub base_color_texture_index: Option<usize>,
    /// Index identifying the metallic-roughness texture, if the material has one.
    pub metallic_roughness_texture_index: Option<usize>,
    /// Index identifying the normal map, if the material has one.
    pub normal_texture_index: Option<usize>,
    /// Index identifying the ambient-occlusion texture, if the material has one.
    pub occlusion_texture_index: Option<usize>,
}

impl Default for Material {
    /// An untextured, white, non-metallic, fully rough material named `default_material`.
    fn default() -> Self {
        Self {
            base_color_texture_index: None,
            metallic_roughness_texture_index: None,
            normal_texture_index: None,
            occlusion_texture_index: None,
            base_color_factor: [1.0; 4],
            metallic_factor: 0.0,
            roughness_factor: 1.0,
            name: String::from("default_material"),
        }
    }
}

68/// 3D Mesh structure
69#[derive(Debug, Clone)]
70pub struct Mesh {
71    pub name: String,
72    pub vertices: Vec<Vertex>,
73    pub indices: Vec<u32>,
74    pub material_index: Option<usize>,
75}
76
77/// 3D Model structure
78#[derive(Debug, Clone)]
79pub struct Model {
80    pub name: String,
81    pub meshes: Vec<Mesh>,
82    pub materials: Vec<Material>,
83    pub images: Vec<ImageData>,
84}
85
86/// glTF file loader
87pub struct GltfLoader;
88
89impl GltfLoader {
90    pub fn load<P: AsRef<Path>>(path: P) -> Result<Model> {
91        let path = path.as_ref();
92        let (doc, buffers, images_data) = gltf::import(path)?;
93
94        let images = images_data.into_iter().map(|mut img| {
95            if img.format == gltf::image::Format::R8G8B8 {
96                let mut new_pixels = Vec::with_capacity(img.pixels.len() / 3 * 4);
97                for chunk in img.pixels.chunks_exact(3) {
98                    new_pixels.extend_from_slice(chunk);
99                    new_pixels.push(255); // Add alpha channel
100                }
101                img.pixels = new_pixels;
102                img.format = gltf::image::Format::R8G8B8A8;
103            }
104            ImageData {
105                width: img.width,
106                height: img.height,
107                format: img.format,
108                pixels: img.pixels,
109            }
110        }).collect();
111
112        let materials = doc.materials().map(|mat| {
113            let pbr = mat.pbr_metallic_roughness();
114            Material {
115                name: mat.name().unwrap_or("").to_string(),
116                base_color_factor: pbr.base_color_factor(),
117                metallic_factor: pbr.metallic_factor(),
118                roughness_factor: pbr.roughness_factor(),
119                base_color_texture_index: pbr.base_color_texture().map(|info| info.texture().index()),
120                metallic_roughness_texture_index: pbr.metallic_roughness_texture().map(|info| info.texture().index()),
121                normal_texture_index: mat.normal_texture().map(|info| info.texture().index()),
122                occlusion_texture_index: mat.occlusion_texture().map(|info| info.texture().index()),
123            }
124        }).collect();
125
126        let mut model = Model {
127            name: path.file_stem().and_then(|s| s.to_str()).unwrap_or("unnamed").to_string(),
128            meshes: Vec::new(),
129            materials,
130            images,
131        };
132
133        for scene in doc.scenes() {
134            for node in scene.nodes() {
135                Self::process_node(node, &buffers, &mut model)?;
136            }
137        }
138
139        Ok(model)
140    }
141
142    fn process_node(node: gltf::Node, buffers: &[gltf::buffer::Data], model: &mut Model) -> Result<()> {
143        if let Some(mesh) = node.mesh() {
144            for primitive in mesh.primitives() {
145                let reader = primitive.reader(|buffer| Some(&buffers[buffer.index()]));
146
147                let positions: Vec<Vec3> = if let Some(iter) = reader.read_positions() {
148                    iter.map(Vec3::from).collect()
149                } else {
150                    continue; // Skip primitives without positions
151                };
152
153                let normals: Vec<Vec3> = if let Some(iter) = reader.read_normals() {
154                    iter.map(Vec3::from).collect()
155                } else {
156                    vec![Vec3::Z; positions.len()] // Generate dummy normals if not present
157                };
158
159                let tex_coords: Vec<Vec2> = if let Some(iter) = reader.read_tex_coords(0) {
160                    iter.into_f32().map(Vec2::from).collect()
161                } else {
162                    vec![Vec2::ZERO; positions.len()] // Generate dummy tex_coords if not present
163                };
164
165                let tangents: Vec<Vec4> = if let Some(iter) = reader.read_tangents() {
166                    iter.map(Vec4::from).collect()
167                } else {
168                    // If tangents are not provided, we should calculate them.
169                    // For now, we'll just use a placeholder.
170                    // TODO: Calculate tangents using a library like `mikktspace`.
171                    vec![Vec4::new(1.0, 0.0, 0.0, 1.0); positions.len()]
172                };
173
174                let joints: Vec<[u32; 4]> = if let Some(iter) = reader.read_joints(0) {
175                    // Convert u16 joints to u32
176                    iter.into_u16().map(|j| [j[0] as u32, j[1] as u32, j[2] as u32, j[3] as u32]).collect()
177                } else {
178                    vec![[0, 0, 0, 0]; positions.len()]
179                };
180
181                let weights: Vec<[f32; 4]> = if let Some(iter) = reader.read_weights(0) {
182                    iter.into_f32().collect()
183                } else {
184                    vec![[1.0, 0.0, 0.0, 0.0]; positions.len()]
185                };
186
187                let vertices: Vec<Vertex> = positions.iter().enumerate().map(|(i, &pos)| Vertex {
188                    position: pos,
189                    normal: normals[i],
190                    tex_coords: tex_coords[i],
191                    tangent: tangents[i],
192                    joints: joints[i],
193                    weights: weights[i],
194                }).collect();
195
196                let indices: Vec<u32> = if let Some(iter) = reader.read_indices() {
197                    iter.into_u32().collect()
198                } else {
199                    (0..vertices.len() as u32).collect()
200                };
201
202                let material_index = primitive.material().index();
203
204                model.meshes.push(Mesh {
205                    name: mesh.name().unwrap_or("unnamed_mesh").to_string(),
206                    vertices,
207                    indices,
208                    material_index,
209                });
210            }
211        }
212
213        for child in node.children() {
214            Self::process_node(child, buffers, model)?;
215        }
216
217        Ok(())
218    }
219}
220
/*
222/// OBJ file loader
223pub struct ObjLoader;
224
225impl ObjLoader {
226    /// Load OBJ file from path
227    pub fn load<P: AsRef<Path>>(path: P) -> Result<Model> {
228        let path = path.as_ref();
229        let content = fs::read_to_string(path)
230            .with_context(|| format!("Failed to read OBJ file: {}", path.display()))?;
231
232        let mut model = Model {
233            name: path.file_stem()
234                .and_then(|s| s.to_str())
235                .unwrap_or("unnamed")
236                .to_string(),
237            meshes: Vec::new(),
238            materials: HashMap::new(),
239            images: Vec::new(),
240        };
241
242        Self::parse_obj(&content, &mut model, path)?;
243
244        Ok(model)
245    }
246
247    /// Parse OBJ file content
248    fn parse_obj(content: &str, model: &mut Model, obj_path: &Path) -> Result<()> {
249        let mut positions: Vec<Vec3> = Vec::new();
250        let mut normals: Vec<Vec3> = Vec::new();
251        let mut tex_coords: Vec<Vec2> = Vec::new();
252        let mut current_mesh: Option<Mesh> = None;
253        let mut current_material = None;
254
255        for (_line_num, line) in content.lines().enumerate() {
256            let line = line.trim();
257
258            if line.is_empty() || line.starts_with('#') {
259                continue;
260            }
261
262            let parts: Vec<&str> = line.split_whitespace().collect();
263            if parts.is_empty() {
264                continue;
265            }
266
267            match parts[0] {
268                "mtllib" => {
269                    // Load material library
270                    if parts.len() > 1 {
271                        let mtl_path = obj_path.parent()
272                            .unwrap_or(Path::new("."))
273                            .join(parts[1]);
274                        Self::load_mtl(&mtl_path, &mut model.materials)?;
275                    }
276                }
277
278                "usemtl" => {
279                    // Use material
280                    if parts.len() > 1 {
281                        current_material = Some(parts[1].to_string());
282                    }
283                }
284
285                "o" => {
286                    // New object
287                    if let Some(mesh) = current_mesh.take() {
288                        model.meshes.push(mesh);
289                    }
290
291                    let name = if parts.len() > 1 {
292                        parts[1].to_string()
293                    } else {
294                        format!("object_{}", model.meshes.len())
295                    };
296
297                    current_mesh = Some(Mesh {
298                        name,
299                        vertices: Vec::new(),
300                        indices: Vec::new(),
301                        material_index: None,
302                    });
303                }
304
305                "v" => {
306                    // Vertex position
307                    if parts.len() >= 4 {
308                        let x = parts[1].parse::<f32>()?;
309                        let y = parts[2].parse::<f32>()?;
310                        let z = parts[3].parse::<f32>()?;
311                        positions.push(Vec3::new(x, y, z));
312                    }
313                }
314
315                "vn" => {
316                    // Vertex normal
317                    if parts.len() >= 4 {
318                        let x = parts[1].parse::<f32>()?;
319                        let y = parts[2].parse::<f32>()?;
320                        let z = parts[3].parse::<f32>()?;
321                        normals.push(Vec3::new(x, y, z));
322                    }
323                }
324
325                "vt" => {
326                    // Texture coordinates
327                    if parts.len() >= 3 {
328                        let u = parts[1].parse::<f32>()?;
329                        let v = parts[2].parse::<f32>()?;
330                        tex_coords.push(Vec2::new(u, v));
331                    }
332                }
333
334                "f" => {
335                    // Face
336                    if let Some(ref mut mesh) = current_mesh {
337                        Self::parse_face(&parts[1..], mesh, &positions, &normals, &tex_coords)?;
338                    }
339                }
340
341                _ => {
342                    // Ignore other directives for now
343                }
344            }
345        }
346
347        // Add final mesh
348        if let Some(mesh) = current_mesh.take() {
349            model.meshes.push(mesh);
350        }
351
352        // Assign materials to meshes
353        for mesh in &mut model.meshes {
354            if let Some(ref mat_name) = current_material {
355                if let Some(material) = model.materials.get(mat_name) {
356                    mesh.material = Some(material.clone());
357                }
358            }
359        }
360
361        Ok(())
362    }
363
364    /// Parse face definition
365    fn parse_face(
366        face_parts: &[&str],
367        mesh: &mut Mesh,
368        positions: &[Vec3],
369        normals: &[Vec3],
370        tex_coords: &[Vec2],
371    ) -> Result<()> {
372        if face_parts.len() < 3 {
373            return Ok(()); // Skip invalid faces
374        }
375
376        let mut face_indices = Vec::new();
377
378        for part in face_parts {
379            let indices: Vec<&str> = part.split('/').collect();
380
381            let vertex_idx = if !indices[0].is_empty() {
382                let idx = indices[0].parse::<i32>()?;
383                if idx < 0 {
384                    (positions.len() as i32 + idx) as usize
385                } else {
386                    (idx - 1) as usize
387                }
388            } else {
389                0
390            };
391
392            let tex_idx = if indices.len() > 1 && !indices[1].is_empty() {
393                let idx = indices[1].parse::<i32>()?;
394                if idx < 0 {
395                    (tex_coords.len() as i32 + idx) as usize
396                } else {
397                    (idx - 1) as usize
398                }
399            } else {
400                0
401            };
402
403            let normal_idx = if indices.len() > 2 && !indices[2].is_empty() {
404                let idx = indices[2].parse::<i32>()?;
405                if idx < 0 {
406                    (normals.len() as i32 + idx) as usize
407                } else {
408                    (idx - 1) as usize
409                }
410            } else {
411                0
412            };
413
414            // Create vertex
415            let vertex = Vertex {
416                position: positions.get(vertex_idx).copied().unwrap_or(Vec3::ZERO),
417                tex_coords: tex_coords.get(tex_idx).copied().unwrap_or(Vec2::ZERO),
418                normal: normals.get(normal_idx).copied().unwrap_or(Vec3::Z),
419                joints: [0, 0, 0, 0],
420                weights: [1.0, 0.0, 0.0, 0.0],
421            };
422
423            let vertex_index = mesh.vertices.len() as u32;
424            mesh.vertices.push(vertex);
425            face_indices.push(vertex_index);
426        }
427
428        // Triangulate face (simple fan triangulation)
429        if face_indices.len() >= 3 {
430            for i in 1..face_indices.len() - 1 {
431                mesh.indices.push(face_indices[0]);
432                mesh.indices.push(face_indices[i]);
433                mesh.indices.push(face_indices[i + 1]);
434            }
435        }
436
437        Ok(())
438    }
439
440    /// Load MTL material file
441    fn load_mtl(path: &Path, materials: &mut HashMap<String, Material>) -> Result<()> {
442        if !path.exists() {
443            return Ok(()); // MTL file is optional
444        }
445
446        let content = fs::read_to_string(path)
447            .with_context(|| format!("Failed to read MTL file: {}", path.display()))?;
448
449        let mut current_material: Option<Material> = None;
450
451        for line in content.lines() {
452            let line = line.trim();
453
454            if line.is_empty() || line.starts_with('#') {
455                continue;
456            }
457
458            let parts: Vec<&str> = line.split_whitespace().collect();
459            if parts.is_empty() {
460                continue;
461            }
462
463            match parts[0] {
464                "newmtl" => {
465                    // Save previous material
466                    if let Some(mat) = current_material.take() {
467                        materials.insert(mat.name.clone(), mat);
468                    }
469
470                    // Create new material
471                    let name = parts.get(1).unwrap_or(&"unnamed").to_string();
472                    current_material = Some(Material {
473                        name,
474                        diffuse_color: Vec3::ONE,
475                        specular_color: Vec3::ZERO,
476                        shininess: 1.0,
477                        diffuse_texture_index: None,
478                    });
479                }
480
481                "Kd" => {
482                    // Diffuse color
483                    if let Some(ref mut mat) = current_material {
484                        if parts.len() >= 4 {
485                            mat.diffuse_color = Vec3::new(
486                                parts[1].parse().unwrap_or(1.0),
487                                parts[2].parse().unwrap_or(1.0),
488                                parts[3].parse().unwrap_or(1.0),
489                            );
490                        }
491                    }
492                }
493
494                "Ks" => {
495                    // Specular color
496                    if let Some(ref mut mat) = current_material {
497                        if parts.len() >= 4 {
498                            mat.specular_color = Vec3::new(
499                                parts[1].parse().unwrap_or(0.0),
500                                parts[2].parse().unwrap_or(0.0),
501                                parts[3].parse().unwrap_or(0.0),
502                            );
503                        }
504                    }
505                }
506
507                "Ns" => {
508                    // Shininess
509                    if let Some(ref mut mat) = current_material {
510                        if parts.len() >= 2 {
511                            mat.shininess = parts[1].parse().unwrap_or(1.0);
512                        }
513                    }
514                }
515
516                "map_Kd" => {
517                    // Diffuse texture
518                    if let Some(ref mut mat) = current_material {
519                        if parts.len() >= 2 {
520                            mat.diffuse_texture_index = Some(parts[1].to_string());
521                        }
522                    }
523                }
524
525                _ => {
526                    // Ignore other material properties
527                }
528            }
529        }
530
531        // Save final material
532        if let Some(mat) = current_material.take() {
533            materials.insert(mat.name.clone(), mat);
534        }
535
536        Ok(())
537    }
538}
*/

541/// Convert model to renderable format
542pub fn model_to_render_data(model: &Model) -> (Vec<Vertex>, Vec<u32>) {
543    let mut all_vertices = Vec::new();
544    let mut all_indices = Vec::new();
545    let mut vertex_offset = 0;
546
547    for mesh in &model.meshes {
548        // Add vertices
549        all_vertices.extend_from_slice(&mesh.vertices);
550
551        // Add indices with offset
552        for &index in &mesh.indices {
553            all_indices.push(index + vertex_offset);
554        }
555
556        vertex_offset += mesh.vertices.len() as u32;
557    }
558
559    (all_vertices, all_indices)
560}
561
/*
563#[cfg(test)]
564mod tests {
565    use super::*;
566    use std::io::Write;
567    use tempfile::NamedTempFile;
568
569    #[test]
570    fn test_simple_obj_loading() {
571        let obj_content = r#"
572# Simple cube
573o cube
574v -1.0 -1.0 -1.0
575v  1.0 -1.0 -1.0
576v  1.0  1.0 -1.0
577v -1.0  1.0 -1.0
578f 1 2 3 4
579"#;
580
581        let mut temp_file = NamedTempFile::new().unwrap();
582        temp_file.write_all(obj_content.as_bytes()).unwrap();
583        temp_file.flush().unwrap();
584
585        let path = temp_file.path();
586        let expected = path
587            .file_stem()
588            .and_then(|s| s.to_str())
589            .unwrap_or("")
590            .to_string();
591        let model = ObjLoader::load(path).unwrap();
592
593        assert_eq!(model.name, expected);
594        assert_eq!(model.meshes.len(), 1);
595        assert_eq!(model.meshes[0].vertices.len(), 4);
596        assert_eq!(model.meshes[0].indices.len(), 6); // 2 triangles
597    }
598}
599*/