all_is_cubes/save/compress.rs

1use alloc::borrow::{Cow, ToOwned};
2use alloc::vec::Vec;
3use core::fmt;
4use std::io::{self, Write as _};
5
6use serde::de::Error as _;
7
/// A slice of `T` which, when serialized, will be compressed in the gzip format.
///
/// To ensure portability, `T` should be endianness-independent (for example byte arrays,
/// or [`Leu16`] instead of `u16`), because it is the elements' in-memory bytes that get
/// compressed.
///
/// The serialized form is a tagged (enum-like) value whose tag records the compression
/// choices, so that they can be changed in the future:
///
/// * If the serde serializer
///   [`is_human_readable()`](serde::Serializer::is_human_readable), the compressed bytes
///   are base64 encoded (standard alphabet, no padding) into a string under the tag
///   `Base64Gzip`, in the hopes of producing a more compact textual result than a list
///   of numbers would be.
/// * Otherwise, the compressed bytes are serialized like a `Vec<u8>` under the tag `Gzip`.
///
/// Deserialization accepts either form, regardless of whether the deserializer is
/// human-readable.
pub(crate) struct GzSerde<'a, T: 'static>(pub Cow<'a, [T]>)
where
    [T]: ToOwned;
21
22impl<T: bytemuck::NoUninit> serde::Serialize for GzSerde<'_, T> {
23    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
24    where
25        S: serde::Serializer,
26    {
27        let uncompressed_bytes = bytemuck::must_cast_slice::<T, u8>(self.0.as_ref());
28
29        let compression = flate2::Compression::fast();
30
31        if serializer.is_human_readable() {
32            let mut gz_encoder = flate2::GzBuilder::new().write(
33                base64::write::EncoderStringWriter::new(&BASE64_ENGINE),
34                compression,
35            );
36            gz_encoder.write_all(uncompressed_bytes).unwrap();
37            let b64_encoder = gz_encoder.finish().unwrap();
38            let b64string = b64_encoder.into_inner();
39
40            GzSerdeInternal::Base64Gzip(Cow::Borrowed(b64string.as_str())).serialize(serializer)
41        } else {
42            let mut gz_encoder = flate2::GzBuilder::new().write(Vec::<u8>::new(), compression);
43            gz_encoder.write_all(uncompressed_bytes).unwrap();
44            let compressed_bytes = gz_encoder.finish().unwrap();
45
46            GzSerdeInternal::Gzip(Cow::Borrowed(compressed_bytes.as_slice())).serialize(serializer)
47        }
48    }
49}
50
51impl<'de, T: bytemuck::CheckedBitPattern> serde::Deserialize<'de> for GzSerde<'_, T> {
52    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
53    where
54        D: serde::Deserializer<'de>,
55    {
56        match GzSerdeInternal::deserialize(deserializer)? {
57            GzSerdeInternal::Base64Gzip(b64string) => {
58                deserialize_common(flate2::bufread::GzDecoder::new(io::BufReader::new(
59                    base64::read::DecoderReader::new(
60                        io::Cursor::new(b64string.as_bytes()),
61                        &BASE64_ENGINE,
62                    ),
63                )))
64                .map_err(|e| D::Error::custom(format!("invalid base64+gzip data: {e}")))
65            }
66            GzSerdeInternal::Gzip(gzip_bytes) => {
67                deserialize_common(flate2::bufread::GzDecoder::new(io::Cursor::new(gzip_bytes)))
68                    .map_err(|e| D::Error::custom(format!("invalid gzip data: {e}")))
69            }
70        }
71    }
72}
73
74fn deserialize_common<T: bytemuck::CheckedBitPattern>(
75    mut r: impl io::Read,
76) -> Result<GzSerde<'static, T>, io::Error> {
77    let mut uncompressed = Vec::new();
78    r.read_to_end(&mut uncompressed)?;
79    Ok(GzSerde(Cow::Owned(
80        bytemuck::checked::try_cast_slice::<u8, T>(&uncompressed)
81            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
82            .to_owned(),
83    )))
84}
85
86impl<T> fmt::Debug for GzSerde<'_, T>
87where
88    [T]: ToOwned,
89    T: fmt::Debug,
90{
91    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92        // Truncate because this is typically very long and not interesting.
93        // TODO: tune length
94        if self.0.len() > 16 {
95            write!(f, "GzSerde({:?}...)", &self.0[..16])
96        } else {
97            write!(f, "GzSerde({:?})", &self.0[..])
98        }
99    }
100}
101
/// Base64 engine for the human-readable (`Base64Gzip`) representation: the
/// `STANDARD_NO_PAD` configuration, i.e. the standard alphabet without `=` padding.
///
/// Both the `Serialize` and `Deserialize` impls of `GzSerde` use this engine, so changing
/// it changes the serialized format and could make previously saved data undecodable.
const BASE64_ENGINE: base64::engine::GeneralPurpose =
    base64::engine::general_purpose::STANDARD_NO_PAD;
104
/// The actual serialized representation of a [`GzSerde`].
///
/// This uses serde's default (externally tagged) enum representation, e.g.
/// `{"Base64Gzip": "..."}` in JSON. The variant name therefore acts as a tag recording
/// how the payload was encoded, and renaming a variant or changing its payload type
/// changes the serialized format.
///
/// The `Cow`s let the `Serialize` impl of `GzSerde` borrow its freshly encoded buffers
/// instead of copying them.
#[derive(serde::Serialize, serde::Deserialize)]
enum GzSerdeInternal<'a> {
    /// Gzip-compressed bytes, base64-encoded with `BASE64_ENGINE`.
    /// Produced when the serializer is human-readable.
    Base64Gzip(Cow<'a, str>),
    /// Gzip-compressed bytes, serialized as a byte sequence.
    /// Produced when the serializer is not human-readable.
    Gzip(Cow<'a, [u8]>),
}
110
111/// u16, but in guaranteed little-endian, unaligned representation.
112#[derive(Copy, Clone, Debug, Default, bytemuck::Pod, bytemuck::Zeroable)]
113#[repr(transparent)]
114pub(crate) struct Leu16([u8; 2]);
115
116impl From<u16> for Leu16 {
117    fn from(value: u16) -> Self {
118        Self(value.to_le_bytes())
119    }
120}
121impl From<Leu16> for u16 {
122    fn from(value: Leu16) -> Self {
123        u16::from_le_bytes(value.0)
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Checks that `value`, serialized with a human-readable serializer (JSON), becomes
    /// exactly the `Base64Gzip` variant holding `expected_base64`, and that deserializing
    /// that JSON gives back `value`.
    #[track_caller]
    fn assert_round_trip<T>(value: &[T], expected_base64: &str)
    where
        T: bytemuck::Pod + Eq + fmt::Debug,
    {
        let serialized = serde_json::to_value(GzSerde(Cow::Borrowed(value)))
            .expect("failed to serialize");
        let expected_json = json!({ "Base64Gzip": expected_base64 });
        assert_eq!(serialized, expected_json, "serialized str != expected str");

        let GzSerde(decoded) = serde_json::from_value::<GzSerde<'_, T>>(serialized)
            .expect("failed to deserialize");
        assert_eq!(&decoded[..], value, "roundtripped value not as expected");

        // TODO: test non-human-readable output
    }

    #[test]
    fn empty() {
        // Even zero elements produce non-empty output, because of the gzip header.
        let nothing: &[[u8; 2]] = &[];
        assert_round_trip(nothing, "H4sIAAAAAAAE/wMAAAAAAAAAAAA");
    }

    #[test]
    fn nonempty() {
        let pairs: [[u8; 2]; 3] = [[1, 2], [3, 4], [5, 6]];
        assert_round_trip(&pairs, "H4sIAAAAAAAE/2NkYmZhZQMAJHf2gQYAAAA");
    }

    #[test]
    fn proof_of_compression() {
        // 20000 bytes of repetitive input must come out as a short string.
        let repeated = [[123u8, 45]; 10_000];
        assert_round_trip(&repeated, "H4sIAAAAAAAE/+3QAQ0AAAiAsEQmtLyzBvtIwHdEgAABAgQIECBAgAABAgQI1AX8ESBAgAABAgQIECBAgAABAn0BhwQIECBAgAABAgQIECBAgEBfwCEBAgQIECBAgAABAgQIECDQF3BIgAABAgQIECBAgAABAgRe4ADS7V+aIE4AAA");
    }
}