Skip to main content

gmac/io/
obj.rs

1use std::fs::File;
2use std::io::{BufRead, BufReader, BufWriter, Write};
3
4use crate::error::Result;
5
6#[cfg(feature = "rayon")]
7use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
8
/// Writes a triangle mesh to a Wavefront OBJ file.
///
/// The file starts with a two-line `#` comment header, followed by one
/// `v x y z` record per vertex and then one `f a b c` record per triangle.
/// Triangle indices are converted from the 0-based indices used in `cells` to
/// the 1-based indices required by OBJ. If the `rayon` feature is enabled, the
/// face records are formatted in parallel and then written in their original
/// order.
///
/// # Arguments
/// * `nodes` - A slice of 3D vertex positions.
/// * `cells` - A slice of triangles, each defined by 0-based indices into the `nodes` array.
/// * `filename` - Optional file path. Defaults to `"mesh.obj"` if `None`.
///
/// # Errors
/// Returns an error if the file cannot be created, or if any write fails,
/// including the final flush of the buffered writer.
pub fn write_obj(
    nodes: &[[f64; 3]],
    cells: &[[usize; 3]],
    filename: Option<&str>,
) -> Result<()> {
    let file = File::create(filename.unwrap_or("mesh.obj"))?;
    let mut writer = BufWriter::new(file);

    // Header
    writeln!(writer, "# OBJ file generated by Rust mesh library")?;
    writeln!(writer, "# {} vertices, {} faces", nodes.len(), cells.len())?;

    // Vertex positions, in input order (their position defines their OBJ index).
    for node in nodes {
        writeln!(writer, "v {} {} {}", node[0], node[1], node[2])?;
    }

    // OBJ face indices are 1-based, so every 0-based index is shifted by one.
    let face_line =
        |cell: &[usize; 3]| format!("f {} {} {}", cell[0] + 1, cell[1] + 1, cell[2] + 1);

    #[cfg(feature = "rayon")]
    {
        // Format in parallel, then write sequentially so face order is preserved.
        let face_lines: Vec<String> = cells.par_iter().map(face_line).collect();
        for line in &face_lines {
            writeln!(writer, "{}", line)?;
        }
    }

    #[cfg(not(feature = "rayon"))]
    {
        // Without rayon there is nothing to gain from buffering every record first.
        for cell in cells {
            writeln!(writer, "{}", face_line(cell))?;
        }
    }

    // `BufWriter` flushes on drop but silently discards any error from doing so;
    // flushing explicitly makes a failed final write visible to the caller.
    writer.flush()?;
    Ok(())
}
58
/// Reads a Wavefront OBJ file and extracts vertex positions and triangles.
///
/// Only `v` (vertex position) and `f` (face) records are interpreted. Every other
/// record (normals, texture coordinates, groups, materials, ...) is ignored, and
/// anything after a `#` on a line is treated as a comment.
///
/// Face vertices may be written as `v`, `v/vt`, `v//vn` or `v/vt/vn`; only the
/// position index is used. Positive indices are 1-based. Negative indices count
/// back from the vertices defined so far, so `-1` is the most recent vertex.
/// Faces with more than three vertices are fan-triangulated around their first
/// vertex; for example, the quad `a b c d` becomes `[a, b, c]` and `[a, c, d]`.
///
/// # Arguments
/// * `filename` - The path to the `.obj` file.
///
/// # Returns
/// A tuple `(nodes, cells)`:
/// - `Vec<[f64; 3]>`: the vertex positions, in file order.
/// - `Vec<[usize; 3]>`: triangles as 0-based indices into `nodes`.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the file cannot be opened or read;
/// - a `v` record does not start with three numeric coordinates;
/// - an `f` record has fewer than three vertices;
/// - an `f` record has an index that is malformed, is `0`, or does not refer to a defined vertex.
pub fn read_obj(filename: &str) -> Result<(Vec<[f64; 3]>, Vec<[usize; 3]>)> {
    let reader = BufReader::new(File::open(filename)?);

    let mut nodes: Vec<[f64; 3]> = Vec::new();
    let mut cells: Vec<[usize; 3]> = Vec::new();
    // Reused for every `f` record to avoid one allocation per face.
    let mut face: Vec<usize> = Vec::new();

    for (line_idx, line) in reader.lines().enumerate() {
        let line = line?;
        let line_no = line_idx + 1;
        // Anything after `#` is a comment, including trailing comments on data lines.
        let data = line.split('#').next().unwrap_or("");
        let mut tokens = data.split_whitespace();

        match tokens.next() {
            // Vertex position line: "v x y z [extra values ignored]"
            Some("v") => {
                let mut position = [0.0_f64; 3];
                for coord in position.iter_mut() {
                    *coord = tokens
                        .next()
                        .and_then(|token| token.parse::<f64>().ok())
                        .ok_or_else(|| {
                            invalid_obj_data(
                                line_no,
                                "vertex must start with three numeric coordinates".to_string(),
                            )
                        })?;
                }
                nodes.push(position);
            }
            // Face definition line: "f v1 v2 v3 ..." (any number of vertices >= 3)
            Some("f") => {
                face.clear();
                for token in tokens {
                    face.push(resolve_obj_index(token, nodes.len(), line_no)?);
                }
                if face.len() < 3 {
                    return Err(invalid_obj_data(
                        line_no,
                        format!("face needs at least three vertices, found {}", face.len()),
                    )
                    .into());
                }
                // Fan triangulation around the first vertex: (v0, vk, vk+1).
                for pair in face[1..].windows(2) {
                    cells.push([face[0], pair[0], pair[1]]);
                }
            }
            // Comments, blank lines and all other record types.
            _ => {}
        }
    }

    // Positive indices may refer to vertices defined later in the file, so the
    // upper bound can only be checked once every vertex has been read.
    if let Some(index) = cells.iter().flatten().copied().find(|&i| i >= nodes.len()) {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!(
                "OBJ face index {} is out of range for {} vertices",
                index + 1,
                nodes.len()
            ),
        )
        .into());
    }

    Ok((nodes, cells))
}

/// Builds the `InvalidData` error reported for malformed OBJ content on line `line_no`.
fn invalid_obj_data(line_no: usize, message: String) -> std::io::Error {
    std::io::Error::new(
        std::io::ErrorKind::InvalidData,
        format!("OBJ line {}: {}", line_no, message),
    )
}

/// Converts one face-vertex token (`v`, `v/vt`, `v//vn` or `v/vt/vn`) into a
/// 0-based vertex index, resolving negative (relative) indices against
/// `vertex_count`, the number of vertices defined so far.
fn resolve_obj_index(token: &str, vertex_count: usize, line_no: usize) -> std::io::Result<usize> {
    let position = token.split('/').next().unwrap_or("");
    let index: isize = position
        .parse()
        .map_err(|_| invalid_obj_data(line_no, format!("invalid face index `{}`", token)))?;

    let resolved = if index > 0 {
        // Absolute, 1-based index.
        Some(index.unsigned_abs() - 1)
    } else if index < 0 {
        // `-1` is the most recently defined vertex, `-2` the one before it, ...
        vertex_count.checked_sub(index.unsigned_abs())
    } else {
        // OBJ indices start at 1, so `0` never refers to a vertex.
        None
    };

    resolved.ok_or_else(|| {
        invalid_obj_data(
            line_no,
            format!("face index `{}` does not refer to a defined vertex", token),
        )
    })
}
126
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::{read_to_string, remove_file};
    use std::io::Write;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// A uniquely named `.obj` path in the system temp directory that is removed
    /// when dropped, so fixtures are cleaned up even if an assertion panics.
    struct TempObj {
        path: PathBuf,
    }

    impl TempObj {
        /// Reserves a fresh path without creating the file. Uniqueness comes from
        /// the process id plus a per-process counter, so tests running on parallel
        /// threads never share a file (unlike a timestamp-based name).
        fn new() -> Self {
            static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
            let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
            let path = std::env::temp_dir()
                .join(format!("gmac_obj_test_{}_{}.obj", std::process::id(), id));
            TempObj { path }
        }

        /// Creates the file and writes `content` followed by a newline.
        fn with_content(content: &str) -> Self {
            let temp = TempObj::new();
            let mut file = File::create(&temp.path).unwrap();
            writeln!(file, "{}", content).unwrap();
            temp
        }

        /// The path as `&str`, which is the form `read_obj`/`write_obj` accept.
        fn as_str(&self) -> &str {
            self.path
                .to_str()
                .expect("temp directory path should be valid UTF-8")
        }
    }

    impl Drop for TempObj {
        fn drop(&mut self) {
            // Best effort: the file may never have been created.
            let _ = remove_file(&self.path);
        }
    }

    #[test]
    fn test_write_obj_simple() {
        let nodes = vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let cells = vec![[0, 1, 2]];
        let temp = TempObj::new();

        write_obj(&nodes, &cells, Some(temp.as_str())).unwrap();

        let content = read_to_string(&temp.path).unwrap();
        assert!(content.contains("v 1 2 3"));
        assert!(content.contains("v 4 5 6"));
        assert!(content.contains("v 7 8 9"));
        // OBJ is 1-based, so indices are incremented.
        assert!(content.contains("f 1 2 3"));
    }

    #[test]
    fn test_read_obj_simple() {
        let obj_content = "
# Test comment
v 1.0 0.0 0.0
v 0.0 1.0 0.0
v 0.0 0.0 1.0

f 1 2 3
";
        let temp = TempObj::with_content(obj_content);
        let (nodes, cells) = read_obj(temp.as_str()).unwrap();

        assert_eq!(nodes.len(), 3);
        assert_eq!(cells.len(), 1);
        assert_eq!(nodes[0], [1.0, 0.0, 0.0]);
        // read_obj converts from 1-based to 0-based indices.
        assert_eq!(cells[0], [0, 1, 2]);
    }

    #[test]
    fn test_read_obj_complex_faces() {
        let obj_content = "
v 1.0 0.0 0.0
v 0.0 1.0 0.0
v 0.0 0.0 1.0
vt 0.0 0.0
vn 0.0 1.0 0.0

f 1/1/1 2/1/1 3/1/1
";
        let temp = TempObj::with_content(obj_content);
        let (nodes, cells) = read_obj(temp.as_str()).unwrap();

        assert_eq!(nodes.len(), 3);
        assert_eq!(cells.len(), 1);
        assert_eq!(cells[0], [0, 1, 2], "Should correctly parse v/vt/vn format");
    }

    #[test]
    fn test_read_obj_quad_triangulation() {
        let obj_content = "
v 0.0 0.0 0.0
v 1.0 0.0 0.0
v 1.0 1.0 0.0
v 0.0 1.0 0.0

# A single quad face
f 1 2 3 4
";
        let temp = TempObj::with_content(obj_content);
        let (_, cells) = read_obj(temp.as_str()).unwrap();

        assert_eq!(
            cells.len(),
            2,
            "A quad should be triangulated into two faces"
        );
        assert_eq!(cells[0], [0, 1, 2]);
        assert_eq!(cells[1], [0, 2, 3]);
    }
}