1use crate::error::{DecodeError, DecodeResult, ParseError, ParseResult, Result, UnzipResult};
2use crate::metajson::MetaJsonType;
3use crate::types::{Means, Quats, Scales, Sh0, ShN, SogDataV2, Splat};
4use image_webp::WebPDecoder;
5use std::collections::HashMap;
6use std::io::{Cursor, Read};
7use zip::ZipArchive;
8use zip::result::ZipError;
9
10fn unzip(file_data: &[u8]) -> UnzipResult<HashMap<String, Vec<u8>>> {
12 let cursor = Cursor::new(file_data);
13 let mut archive = ZipArchive::new(cursor)?;
14 let mut files = HashMap::new();
15
16 for i in 0..archive.len() {
17 let mut zip_file = archive.by_index(i)?;
18 let mut buf = Vec::with_capacity(zip_file.size() as usize);
19 let _size = zip_file.read_to_end(&mut buf).map_err(ZipError::Io)?;
20 files.insert(zip_file.name().to_owned(), buf);
21 }
22
23 Ok(files)
24}
25
26fn parse_sog(files: HashMap<String, Vec<u8>>) -> ParseResult<SogDataV2> {
27 let meta_bytes = files.get("meta.json").ok_or(ParseError::MetaJsonNotFound)?;
28
29 let meta_json_string = str::from_utf8(meta_bytes)
30 .map_err(|_| ParseError::InvalidMetaJson("encoding is not utf8".to_string()))?;
31
32 let meta_json = serde_json::from_str::<MetaJsonType>(meta_json_string)
33 .map_err(ParseError::DeserializeMetaJson)?;
34
35 if meta_json.version != 2 {
36 return Err(ParseError::InvalidMetaJson("version is not 2".to_string()));
37 }
38
39 let means_l_name = meta_json
40 .means
41 .files
42 .first()
43 .ok_or(ParseError::InvalidMetaJson(
44 "missing means_l file name".to_string(),
45 ))?;
46 let means_u_name = meta_json
47 .means
48 .files
49 .get(1)
50 .ok_or(ParseError::InvalidMetaJson(
51 "missing means_u file name".to_string(),
52 ))?;
53 let means = Means {
54 mins: meta_json.means.mins.try_into()?,
55 maxs: meta_json.means.maxs.try_into()?,
56 means_l: files
57 .get(means_l_name)
58 .ok_or(ParseError::ImageNotFound(means_l_name.to_string()))?
59 .clone(),
60 means_u: files
61 .get(means_u_name)
62 .ok_or(ParseError::ImageNotFound(means_u_name.to_string()))?
63 .clone(),
64 };
65
66 let scales_name = meta_json
67 .scales
68 .files
69 .first()
70 .ok_or(ParseError::InvalidMetaJson(
71 "missing scales file name".to_string(),
72 ))?;
73 let scales = Scales {
74 codebook: meta_json.scales.codebook.as_slice().try_into()?,
75 scales: files
76 .get(scales_name)
77 .ok_or(ParseError::ImageNotFound(scales_name.to_string()))?
78 .clone(),
79 };
80
81 let quats_name = meta_json
82 .quats
83 .files
84 .first()
85 .ok_or(ParseError::InvalidMetaJson(
86 "missing quats file name".to_string(),
87 ))?;
88 let quats = Quats(
89 files
90 .get(quats_name)
91 .ok_or(ParseError::ImageNotFound(quats_name.to_string()))?
92 .clone(),
93 );
94
95 let sh0_name = meta_json
96 .sh0
97 .files
98 .first()
99 .ok_or(ParseError::InvalidMetaJson(
100 "missing sh0 file name".to_string(),
101 ))?;
102 let sh_0 = Sh0 {
103 codebook: meta_json.sh0.codebook.as_slice().try_into()?,
104 sh_0: files
105 .get(sh0_name)
106 .ok_or(ParseError::ImageNotFound(sh0_name.to_string()))?
107 .clone(),
108 };
109
110 let sh_n = if let Some(sh_n) = meta_json.sh_n {
111 let centroids_name = sh_n.files.first().ok_or(ParseError::InvalidMetaJson(
112 "missing centroids file name".to_string(),
113 ))?;
114 let labels_name = sh_n.files.get(1).ok_or(ParseError::InvalidMetaJson(
115 "missing labels file name".to_string(),
116 ))?;
117 Some(ShN {
118 count: sh_n.count,
119 bands: sh_n.bands,
120 codebook: sh_n.codebook.as_slice().try_into()?,
121 centroids: files
122 .get(centroids_name)
123 .ok_or(ParseError::ImageNotFound(centroids_name.to_string()))?
124 .clone(),
125 labels: files
126 .get(labels_name)
127 .ok_or(ParseError::ImageNotFound(labels_name.to_string()))?
128 .clone(),
129 })
130 } else {
131 None
132 };
133
134 Ok(SogDataV2 {
135 count: meta_json.count,
136 antialias: meta_json.antialias.unwrap_or(false),
137 means,
138 quats,
139 scales,
140 sh_0,
141 sh_n,
142 })
143}
144
145pub fn unpack(file: &[u8]) -> Result<SogDataV2> {
146 let files = unzip(file)?;
147 let sog_data = parse_sog(files)?;
148 Ok(sog_data)
149}
150
151#[allow(clippy::identity_op)]
152fn decode_positions(means: &Means, count: usize) -> DecodeResult<Vec<f32>> {
153 let Means {
154 mins,
155 maxs,
156 means_l,
157 means_u,
158 } = means;
159
160 let cursor = Cursor::new(means_l);
161 let mut decoder = WebPDecoder::new(cursor)?;
162 let output_size = decoder.output_buffer_size().ok_or_else(|| {
163 DecodeError::InvalidSize("Failed to get output buffer size of WebP image".to_string())
164 })?;
165 let mut lower_pixels = vec![0u8; output_size];
166 decoder.read_image(&mut lower_pixels)?;
167
168 let cursor = Cursor::new(means_u);
169 let mut decoder = WebPDecoder::new(cursor)?;
170 let output_size = decoder.output_buffer_size().ok_or_else(|| {
171 DecodeError::InvalidSize("Failed to get output buffer size of WebP image".to_string())
172 })?;
173 let mut upper_pixels = vec![0u8; output_size];
174 decoder.read_image(&mut upper_pixels)?;
175
176 if lower_pixels.len() != upper_pixels.len() {
178 return Err(DecodeError::InvalidSize(
179 "Lower and upper pixels have different length".to_string(),
180 ));
181 } else if lower_pixels.len() % 4 != 0 {
182 return Err(DecodeError::InvalidSize(format!(
183 "lower image size cannot be divided by 4: {}",
184 lower_pixels.len()
185 )));
186 } else if upper_pixels.len() % 4 != 0 {
187 return Err(DecodeError::InvalidSize(format!(
188 "upper image size cannot be divided by 4: {}",
189 upper_pixels.len()
190 )));
191 }
192
193 let mut positions = vec![0f32; count * 3];
194 for i in 0..count {
195 let pos_x = ((upper_pixels[i * 4 + 0] as u16) << 8) | (lower_pixels[i * 4 + 0] as u16);
196 let pos_y = ((upper_pixels[i * 4 + 1] as u16) << 8) | (lower_pixels[i * 4 + 1] as u16);
197 let pos_z = ((upper_pixels[i * 4 + 2] as u16) << 8) | (lower_pixels[i * 4 + 2] as u16);
198
199 fn lerp(a: f32, b: f32, t: f32) -> f32 {
200 a + t * (b - a)
201 }
202 fn unlog(x: f32) -> f32 {
203 f32::signum(x) * (f32::exp(f32::abs(x)) - 1.0)
204 }
205
206 positions[i * 3 + 0] = unlog(lerp(mins.x, maxs.x, pos_x as f32 / 65535.0));
207 positions[i * 3 + 1] = unlog(lerp(mins.y, maxs.y, pos_y as f32 / 65535.0));
208 positions[i * 3 + 2] = unlog(lerp(mins.z, maxs.z, pos_z as f32 / 65535.0));
209 }
210
211 Ok(positions)
212}
213
214#[allow(clippy::identity_op)]
216fn decode_rotations(quats: &Quats, count: usize) -> DecodeResult<Vec<f32>> {
217 let cursor = Cursor::new(&quats.0);
218 let mut decoder = WebPDecoder::new(cursor)?;
219 let output_size = decoder
220 .output_buffer_size()
221 .ok_or_else(|| DecodeError::InvalidSize("cannot determine output size".to_string()))?;
222 let mut pixels = vec![0u8; output_size];
223 decoder.read_image(&mut pixels)?;
224
225 fn to_comp(x: f32) -> f32 {
226 (x / 255.0 - 0.5) * 2.0 / f32::sqrt(2.0)
227 }
228
229 let mut rotations = vec![0f32; count * 4];
230
231 for i in 0..count {
232 let a = to_comp(pixels[i * 4 + 0] as f32);
233 let b = to_comp(pixels[i * 4 + 1] as f32);
234 let c = to_comp(pixels[i * 4 + 2] as f32);
235 let m = pixels[i * 4 + 3];
236
237 if m < 252 {
238 return Err(DecodeError::InvalidData(format!(
239 "invalid rotation mode(m<252): {}, index: {}",
240 m, i
241 )));
242 }
243
244 let mode = match m - 252 {
245 0u8 => Ok(0u8),
246 1u8 => Ok(1u8),
247 2u8 => Ok(2u8),
248 3u8 => Ok(3u8),
249 _ => Err(DecodeError::InvalidData(format!(
250 "invalid rotation mode: {}",
251 pixels[i * 4 + 3] - 252
252 ))),
253 }?;
254 let d = f32::sqrt(f32::max(0.0, 1.0 - a * a - b * b - c * c));
255
256 let q = match mode {
257 0 => [d, a, b, c],
258 1 => [a, d, b, c],
259 2 => [a, b, d, c],
260 3 => [a, b, c, d],
261 _ => unreachable!(),
262 };
263 rotations[i * 4 + 0] = q[0];
264 rotations[i * 4 + 1] = q[1];
265 rotations[i * 4 + 2] = q[2];
266 rotations[i * 4 + 3] = q[3];
267 }
268
269 Ok(rotations)
270}
271
272#[allow(clippy::identity_op)]
273fn decode_scales(scales: &Scales, count: usize) -> DecodeResult<Vec<f32>> {
274 let Scales { codebook, scales } = scales;
275
276 let cursor = Cursor::new(scales);
277 let mut decoder = WebPDecoder::new(cursor)?;
278 let output_size = decoder
279 .output_buffer_size()
280 .ok_or_else(|| DecodeError::InvalidSize("cannot determine output size".to_string()))?;
281 let mut pixels = vec![0u8; output_size];
282 decoder.read_image(&mut pixels)?;
283
284 if pixels.len() % 4 != 0 {
285 return Err(DecodeError::InvalidData(format!(
286 "scale image size cannot be divided by 4: {}",
287 pixels.len()
288 )));
289 }
290
291 let mut scales = vec![0f32; count * 3];
292 for i in 0..count {
293 scales[i * 3 + 0] = codebook.0[pixels[i * 4 + 0] as usize];
294 scales[i * 3 + 1] = codebook.0[pixels[i * 4 + 1] as usize];
295 scales[i * 3 + 2] = codebook.0[pixels[i * 4 + 2] as usize];
296 }
297
298 Ok(scales)
299}
300
301#[allow(clippy::identity_op)]
302fn decode_sh_0(sh0: &Sh0, count: usize) -> DecodeResult<Vec<f32>> {
303 let Sh0 {
306 codebook,
307 sh_0: sh0,
308 } = sh0;
309
310 let cursor = Cursor::new(sh0);
311 let mut decoder = WebPDecoder::new(cursor)?;
312 let output_size = decoder
313 .output_buffer_size()
314 .ok_or_else(|| DecodeError::InvalidSize("cannot determine output size".to_string()))?;
315 let mut pixels = vec![0u8; output_size];
316 decoder.read_image(&mut pixels)?;
317
318 if pixels.len() % 4 != 0 {
319 return Err(DecodeError::InvalidData(format!(
320 "color image size cannot be divided by 4: {}",
321 pixels.len()
322 )));
323 }
324
325 fn sigmoid_inv(y: f32) -> f32 {
327 let e = y.clamp(1e-6, 1.0 - 1e-6);
328 (e / (1.0 - e)).ln()
329 }
330
331 let mut colors = vec![0f32; count * 4];
332 for i in 0..count {
333 colors[i * 4 + 0] = codebook.0[pixels[i * 4 + 0] as usize];
337 colors[i * 4 + 1] = codebook.0[pixels[i * 4 + 1] as usize];
338 colors[i * 4 + 2] = codebook.0[pixels[i * 4 + 2] as usize];
339 colors[i * 4 + 3] = sigmoid_inv(pixels[i * 4 + 3] as f32 / 255.0);
340 }
341
342 Ok(colors)
343}
344
345#[allow(clippy::identity_op)]
346fn decode_sh_n(sh_n: &ShN, count: usize) -> DecodeResult<Vec<f32>> {
347 let ShN {
348 bands,
349 codebook,
350 centroids,
351 labels,
352 count: _,
353 } = sh_n;
354
355 if *bands <= 0 || *bands >= 4 {
356 return Err(DecodeError::InvalidSize(format!(
357 "invalid bands count: {}",
358 bands
359 )));
360 }
361
362 let cursor = Cursor::new(centroids);
363 let mut decoder = WebPDecoder::new(cursor)?;
364 let output_size = decoder
365 .output_buffer_size()
366 .ok_or_else(|| DecodeError::InvalidSize("cannot determine output size".to_string()))?;
367 let mut centroids_pixels = vec![0u8; output_size];
368 decoder.read_image(&mut centroids_pixels)?;
369
370 let cursor = Cursor::new(labels);
371 let mut decoder = WebPDecoder::new(cursor)?;
372 let output_size = decoder
373 .output_buffer_size()
374 .ok_or_else(|| DecodeError::InvalidSize("cannot determine output size".to_string()))?;
375 let mut labels_pixels = vec![0u8; output_size];
376 decoder.read_image(&mut labels_pixels)?;
377
378 if centroids_pixels.len() % 3 != 0 || labels_pixels.len() % 4 != 0 {
379 return Err(DecodeError::InvalidSize(
380 "invalid image dimensions".to_string(),
381 ));
382 }
383
384 let coeff_count = match bands {
386 1 => 3,
387 2 => 8,
388 3 => 15,
389 _ => Err(DecodeError::InvalidData(format!(
390 "invalid sh bands:{}",
391 bands
392 )))?,
393 };
394
395 let mut sh_n_s = vec![0f32; count * coeff_count * 3];
396 for splat_index in 0..count {
397 let palette_index = ((labels_pixels[splat_index * 4 + 0] as u16)
398 | ((labels_pixels[splat_index * 4 + 1] as u16) << 8))
399 as usize;
400
401 for i in 0..3 {
402 for coeff_index in 0..coeff_count {
403 let index = (splat_index * 3 + i) * coeff_count + coeff_index;
404 let index2 = (palette_index * coeff_count + coeff_index) * 3 + i;
405 sh_n_s[index] = codebook.0[centroids_pixels[index2] as usize];
406 }
407 }
408 }
409
410 Ok(sh_n_s)
411}
412
413pub fn decode(sog_data: &SogDataV2) -> Result<Splat> {
414 let SogDataV2 {
415 means,
416 quats,
417 scales,
418 sh_0,
419 sh_n,
420 ..
421 } = sog_data;
422
423 let count = sog_data.count as usize;
424
425 let splat = Splat {
426 position: decode_positions(means, count)?,
427 rotation: decode_rotations(quats, count)?,
428 scale: decode_scales(scales, count)?,
429 sh_0: decode_sh_0(sh_0, count)?,
430 sh_n: if let Some(s) = sh_n {
431 Some(decode_sh_n(s, count)?)
432 } else {
433 None
434 },
435 count,
436 antialias: sog_data.antialias,
437 sh_degree: sh_n.as_ref().map(|s| s.bands as usize).unwrap_or(0usize),
438 };
439
440 Ok(splat)
441}