Skip to main content

sog_decoder/
decode.rs

1use crate::error::{DecodeError, DecodeResult, ParseError, ParseResult, Result, UnzipResult};
2use crate::metajson::MetaJsonType;
3use crate::types::{Means, Quats, Scales, Sh0, ShN, SogDataV2, Splat};
4use image_webp::WebPDecoder;
5use std::collections::HashMap;
6use std::io::{Cursor, Read};
7use zip::ZipArchive;
8use zip::result::ZipError;
9
10/// Unzip a zip file and return a HashMap of file names and their contents.
11fn unzip(file_data: &[u8]) -> UnzipResult<HashMap<String, Vec<u8>>> {
12    let cursor = Cursor::new(file_data);
13    let mut archive = ZipArchive::new(cursor)?;
14    let mut files = HashMap::new();
15
16    for i in 0..archive.len() {
17        let mut zip_file = archive.by_index(i)?;
18        let mut buf = Vec::with_capacity(zip_file.size() as usize);
19        let _size = zip_file.read_to_end(&mut buf).map_err(ZipError::Io)?;
20        files.insert(zip_file.name().to_owned(), buf);
21    }
22
23    Ok(files)
24}
25
26fn parse_sog(files: HashMap<String, Vec<u8>>) -> ParseResult<SogDataV2> {
27    let meta_bytes = files.get("meta.json").ok_or(ParseError::MetaJsonNotFound)?;
28
29    let meta_json_string = str::from_utf8(meta_bytes)
30        .map_err(|_| ParseError::InvalidMetaJson("encoding is not utf8".to_string()))?;
31
32    let meta_json = serde_json::from_str::<MetaJsonType>(meta_json_string)
33        .map_err(ParseError::DeserializeMetaJson)?;
34
35    if meta_json.version != 2 {
36        return Err(ParseError::InvalidMetaJson("version is not 2".to_string()));
37    }
38
39    let means_l_name = meta_json
40        .means
41        .files
42        .first()
43        .ok_or(ParseError::InvalidMetaJson(
44            "missing means_l file name".to_string(),
45        ))?;
46    let means_u_name = meta_json
47        .means
48        .files
49        .get(1)
50        .ok_or(ParseError::InvalidMetaJson(
51            "missing means_u file name".to_string(),
52        ))?;
53    let means = Means {
54        mins: meta_json.means.mins.try_into()?,
55        maxs: meta_json.means.maxs.try_into()?,
56        means_l: files
57            .get(means_l_name)
58            .ok_or(ParseError::ImageNotFound(means_l_name.to_string()))?
59            .clone(),
60        means_u: files
61            .get(means_u_name)
62            .ok_or(ParseError::ImageNotFound(means_u_name.to_string()))?
63            .clone(),
64    };
65
66    let scales_name = meta_json
67        .scales
68        .files
69        .first()
70        .ok_or(ParseError::InvalidMetaJson(
71            "missing scales file name".to_string(),
72        ))?;
73    let scales = Scales {
74        codebook: meta_json.scales.codebook.as_slice().try_into()?,
75        scales: files
76            .get(scales_name)
77            .ok_or(ParseError::ImageNotFound(scales_name.to_string()))?
78            .clone(),
79    };
80
81    let quats_name = meta_json
82        .quats
83        .files
84        .first()
85        .ok_or(ParseError::InvalidMetaJson(
86            "missing quats file name".to_string(),
87        ))?;
88    let quats = Quats(
89        files
90            .get(quats_name)
91            .ok_or(ParseError::ImageNotFound(quats_name.to_string()))?
92            .clone(),
93    );
94
95    let sh0_name = meta_json
96        .sh0
97        .files
98        .first()
99        .ok_or(ParseError::InvalidMetaJson(
100            "missing sh0 file name".to_string(),
101        ))?;
102    let sh_0 = Sh0 {
103        codebook: meta_json.sh0.codebook.as_slice().try_into()?,
104        sh_0: files
105            .get(sh0_name)
106            .ok_or(ParseError::ImageNotFound(sh0_name.to_string()))?
107            .clone(),
108    };
109
110    let sh_n = if let Some(sh_n) = meta_json.sh_n {
111        let centroids_name = sh_n.files.first().ok_or(ParseError::InvalidMetaJson(
112            "missing centroids file name".to_string(),
113        ))?;
114        let labels_name = sh_n.files.get(1).ok_or(ParseError::InvalidMetaJson(
115            "missing labels file name".to_string(),
116        ))?;
117        Some(ShN {
118            count: sh_n.count,
119            bands: sh_n.bands,
120            codebook: sh_n.codebook.as_slice().try_into()?,
121            centroids: files
122                .get(centroids_name)
123                .ok_or(ParseError::ImageNotFound(centroids_name.to_string()))?
124                .clone(),
125            labels: files
126                .get(labels_name)
127                .ok_or(ParseError::ImageNotFound(labels_name.to_string()))?
128                .clone(),
129        })
130    } else {
131        None
132    };
133
134    Ok(SogDataV2 {
135        count: meta_json.count,
136        antialias: meta_json.antialias.unwrap_or(false),
137        means,
138        quats,
139        scales,
140        sh_0,
141        sh_n,
142    })
143}
144
145pub fn unpack(file: &[u8]) -> Result<SogDataV2> {
146    let files = unzip(file)?;
147    let sog_data = parse_sog(files)?;
148    Ok(sog_data)
149}
150
151#[allow(clippy::identity_op)]
152fn decode_positions(means: &Means, count: usize) -> DecodeResult<Vec<f32>> {
153    let Means {
154        mins,
155        maxs,
156        means_l,
157        means_u,
158    } = means;
159
160    let cursor = Cursor::new(means_l);
161    let mut decoder = WebPDecoder::new(cursor)?;
162    let output_size = decoder.output_buffer_size().ok_or_else(|| {
163        DecodeError::InvalidSize("Failed to get output buffer size of WebP image".to_string())
164    })?;
165    let mut lower_pixels = vec![0u8; output_size];
166    decoder.read_image(&mut lower_pixels)?;
167
168    let cursor = Cursor::new(means_u);
169    let mut decoder = WebPDecoder::new(cursor)?;
170    let output_size = decoder.output_buffer_size().ok_or_else(|| {
171        DecodeError::InvalidSize("Failed to get output buffer size of WebP image".to_string())
172    })?;
173    let mut upper_pixels = vec![0u8; output_size];
174    decoder.read_image(&mut upper_pixels)?;
175
176    // sanitize
177    if lower_pixels.len() != upper_pixels.len() {
178        return Err(DecodeError::InvalidSize(
179            "Lower and upper pixels have different length".to_string(),
180        ));
181    } else if lower_pixels.len() % 4 != 0 {
182        return Err(DecodeError::InvalidSize(format!(
183            "lower image size cannot be divided by 4: {}",
184            lower_pixels.len()
185        )));
186    } else if upper_pixels.len() % 4 != 0 {
187        return Err(DecodeError::InvalidSize(format!(
188            "upper image size cannot be divided by 4: {}",
189            upper_pixels.len()
190        )));
191    }
192
193    let mut positions = vec![0f32; count * 3];
194    for i in 0..count {
195        let pos_x = ((upper_pixels[i * 4 + 0] as u16) << 8) | (lower_pixels[i * 4 + 0] as u16);
196        let pos_y = ((upper_pixels[i * 4 + 1] as u16) << 8) | (lower_pixels[i * 4 + 1] as u16);
197        let pos_z = ((upper_pixels[i * 4 + 2] as u16) << 8) | (lower_pixels[i * 4 + 2] as u16);
198
199        fn lerp(a: f32, b: f32, t: f32) -> f32 {
200            a + t * (b - a)
201        }
202        fn unlog(x: f32) -> f32 {
203            f32::signum(x) * (f32::exp(f32::abs(x)) - 1.0)
204        }
205
206        positions[i * 3 + 0] = unlog(lerp(mins.x, maxs.x, pos_x as f32 / 65535.0));
207        positions[i * 3 + 1] = unlog(lerp(mins.y, maxs.y, pos_y as f32 / 65535.0));
208        positions[i * 3 + 2] = unlog(lerp(mins.z, maxs.z, pos_z as f32 / 65535.0));
209    }
210
211    Ok(positions)
212}
213
214/// return: f32(x,y,z,w)
215#[allow(clippy::identity_op)]
216fn decode_rotations(quats: &Quats, count: usize) -> DecodeResult<Vec<f32>> {
217    let cursor = Cursor::new(&quats.0);
218    let mut decoder = WebPDecoder::new(cursor)?;
219    let output_size = decoder
220        .output_buffer_size()
221        .ok_or_else(|| DecodeError::InvalidSize("cannot determine output size".to_string()))?;
222    let mut pixels = vec![0u8; output_size];
223    decoder.read_image(&mut pixels)?;
224
225    fn to_comp(x: f32) -> f32 {
226        (x / 255.0 - 0.5) * 2.0 / f32::sqrt(2.0)
227    }
228
229    let mut rotations = vec![0f32; count * 4];
230
231    for i in 0..count {
232        let a = to_comp(pixels[i * 4 + 0] as f32);
233        let b = to_comp(pixels[i * 4 + 1] as f32);
234        let c = to_comp(pixels[i * 4 + 2] as f32);
235        let m = pixels[i * 4 + 3];
236
237        if m < 252 {
238            return Err(DecodeError::InvalidData(format!(
239                "invalid rotation mode(m<252): {}, index: {}",
240                m, i
241            )));
242        }
243
244        let mode = match m - 252 {
245            0u8 => Ok(0u8),
246            1u8 => Ok(1u8),
247            2u8 => Ok(2u8),
248            3u8 => Ok(3u8),
249            _ => Err(DecodeError::InvalidData(format!(
250                "invalid rotation mode: {}",
251                pixels[i * 4 + 3] - 252
252            ))),
253        }?;
254        let d = f32::sqrt(f32::max(0.0, 1.0 - a * a - b * b - c * c));
255
256        let q = match mode {
257            0 => [d, a, b, c],
258            1 => [a, d, b, c],
259            2 => [a, b, d, c],
260            3 => [a, b, c, d],
261            _ => unreachable!(),
262        };
263        rotations[i * 4 + 0] = q[0];
264        rotations[i * 4 + 1] = q[1];
265        rotations[i * 4 + 2] = q[2];
266        rotations[i * 4 + 3] = q[3];
267    }
268
269    Ok(rotations)
270}
271
272#[allow(clippy::identity_op)]
273fn decode_scales(scales: &Scales, count: usize) -> DecodeResult<Vec<f32>> {
274    let Scales { codebook, scales } = scales;
275
276    let cursor = Cursor::new(scales);
277    let mut decoder = WebPDecoder::new(cursor)?;
278    let output_size = decoder
279        .output_buffer_size()
280        .ok_or_else(|| DecodeError::InvalidSize("cannot determine output size".to_string()))?;
281    let mut pixels = vec![0u8; output_size];
282    decoder.read_image(&mut pixels)?;
283
284    if pixels.len() % 4 != 0 {
285        return Err(DecodeError::InvalidData(format!(
286            "scale image size cannot be divided by 4: {}",
287            pixels.len()
288        )));
289    }
290
291    let mut scales = vec![0f32; count * 3];
292    for i in 0..count {
293        scales[i * 3 + 0] = codebook.0[pixels[i * 4 + 0] as usize];
294        scales[i * 3 + 1] = codebook.0[pixels[i * 4 + 1] as usize];
295        scales[i * 3 + 2] = codebook.0[pixels[i * 4 + 2] as usize];
296    }
297
298    Ok(scales)
299}
300
301#[allow(clippy::identity_op)]
302fn decode_sh_0(sh0: &Sh0, count: usize) -> DecodeResult<Vec<f32>> {
303    // const SH_C0: f32 = 0.28209479177387814; // SH_C0 = Y_0^0 = 1 / (2 * sqrt(pi))
304
305    let Sh0 {
306        codebook,
307        sh_0: sh0,
308    } = sh0;
309
310    let cursor = Cursor::new(sh0);
311    let mut decoder = WebPDecoder::new(cursor)?;
312    let output_size = decoder
313        .output_buffer_size()
314        .ok_or_else(|| DecodeError::InvalidSize("cannot determine output size".to_string()))?;
315    let mut pixels = vec![0u8; output_size];
316    decoder.read_image(&mut pixels)?;
317
318    if pixels.len() % 4 != 0 {
319        return Err(DecodeError::InvalidData(format!(
320            "color image size cannot be divided by 4: {}",
321            pixels.len()
322        )));
323    }
324
325    // https://github.com/playcanvas/splat-transform/blob/930a9aec511af3665240589b9cf1727d5dcd2eac/src/lib/readers/read-sog.ts#L174
326    fn sigmoid_inv(y: f32) -> f32 {
327        let e = y.clamp(1e-6, 1.0 - 1e-6);
328        (e / (1.0 - e)).ln()
329    }
330
331    let mut colors = vec![0f32; count * 4];
332    for i in 0..count {
333        // colors[i * 4 + 0] = SH_C0 * codebook.0[pixels[i * 4 + 0] as usize] + 0.5;
334        // colors[i * 4 + 1] = SH_C0 * codebook.0[pixels[i * 4 + 1] as usize] + 0.5;
335        // colors[i * 4 + 2] = SH_C0 * codebook.0[pixels[i * 4 + 2] as usize] + 0.5;
336        colors[i * 4 + 0] = codebook.0[pixels[i * 4 + 0] as usize];
337        colors[i * 4 + 1] = codebook.0[pixels[i * 4 + 1] as usize];
338        colors[i * 4 + 2] = codebook.0[pixels[i * 4 + 2] as usize];
339        colors[i * 4 + 3] = sigmoid_inv(pixels[i * 4 + 3] as f32 / 255.0);
340    }
341
342    Ok(colors)
343}
344
345#[allow(clippy::identity_op)]
346fn decode_sh_n(sh_n: &ShN, count: usize) -> DecodeResult<Vec<f32>> {
347    let ShN {
348        bands,
349        codebook,
350        centroids,
351        labels,
352        count: _,
353    } = sh_n;
354
355    if *bands <= 0 || *bands >= 4 {
356        return Err(DecodeError::InvalidSize(format!(
357            "invalid bands count: {}",
358            bands
359        )));
360    }
361
362    let cursor = Cursor::new(centroids);
363    let mut decoder = WebPDecoder::new(cursor)?;
364    let output_size = decoder
365        .output_buffer_size()
366        .ok_or_else(|| DecodeError::InvalidSize("cannot determine output size".to_string()))?;
367    let mut centroids_pixels = vec![0u8; output_size];
368    decoder.read_image(&mut centroids_pixels)?;
369
370    let cursor = Cursor::new(labels);
371    let mut decoder = WebPDecoder::new(cursor)?;
372    let output_size = decoder
373        .output_buffer_size()
374        .ok_or_else(|| DecodeError::InvalidSize("cannot determine output size".to_string()))?;
375    let mut labels_pixels = vec![0u8; output_size];
376    decoder.read_image(&mut labels_pixels)?;
377
378    if centroids_pixels.len() % 3 != 0 || labels_pixels.len() % 4 != 0 {
379        return Err(DecodeError::InvalidSize(
380            "invalid image dimensions".to_string(),
381        ));
382    }
383
384    // calc number of coefficients
385    let coeff_count = match bands {
386        1 => 3,
387        2 => 8,
388        3 => 15,
389        _ => Err(DecodeError::InvalidData(format!(
390            "invalid sh bands:{}",
391            bands
392        )))?,
393    };
394
395    let mut sh_n_s = vec![0f32; count * coeff_count * 3];
396    for splat_index in 0..count {
397        let palette_index = ((labels_pixels[splat_index * 4 + 0] as u16)
398            | ((labels_pixels[splat_index * 4 + 1] as u16) << 8))
399            as usize;
400
401        for i in 0..3 {
402            for coeff_index in 0..coeff_count {
403                let index = (splat_index * 3 + i) * coeff_count + coeff_index;
404                let index2 = (palette_index * coeff_count + coeff_index) * 3 + i;
405                sh_n_s[index] = codebook.0[centroids_pixels[index2] as usize];
406            }
407        }
408    }
409
410    Ok(sh_n_s)
411}
412
413pub fn decode(sog_data: &SogDataV2) -> Result<Splat> {
414    let SogDataV2 {
415        means,
416        quats,
417        scales,
418        sh_0,
419        sh_n,
420        ..
421    } = sog_data;
422
423    let count = sog_data.count as usize;
424
425    let splat = Splat {
426        position: decode_positions(means, count)?,
427        rotation: decode_rotations(quats, count)?,
428        scale: decode_scales(scales, count)?,
429        sh_0: decode_sh_0(sh_0, count)?,
430        sh_n: if let Some(s) = sh_n {
431            Some(decode_sh_n(s, count)?)
432        } else {
433            None
434        },
435        count,
436        antialias: sog_data.antialias,
437        sh_degree: sh_n.as_ref().map(|s| s.bands as usize).unwrap_or(0usize),
438    };
439
440    Ok(splat)
441}