Skip to main content

objects/store/pack/
shared.rs

1// SPDX-License-Identifier: Apache-2.0
2use super::{ObjectType, varint};
3use crate::{
4    object::{ChangeId, ContentHash},
5    store::{Result, StoreError, compression::CompressionConfig},
6};
7
/// Length in bytes of the checksum trailer that closes a pack container.
///
/// `verify_container` treats the final `PACK_CHECKSUM_LEN` bytes of a pack as
/// the stored checksum and compares them with a BLAKE3 digest of everything
/// that precedes them.
pub const PACK_CHECKSUM_LEN: usize = 32;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
11pub enum PackObjectId {
12    Hash(ContentHash),
13    ChangeId(ChangeId),
14}
15
16impl PackObjectId {
17    pub fn encode_tagged(self, buf: &mut Vec<u8>) {
18        match self {
19            Self::Hash(hash) => {
20                buf.push(0);
21                buf.extend_from_slice(hash.as_bytes());
22            }
23            Self::ChangeId(change_id) => {
24                buf.push(1);
25                buf.extend_from_slice(change_id.as_bytes());
26            }
27        }
28    }
29
30    pub fn decode_tagged(data: &[u8]) -> Result<(Self, usize)> {
31        let Some(tag) = data.first().copied() else {
32            return Err(StoreError::InvalidObject(
33                "missing pack object id tag".to_string(),
34            ));
35        };
36        match tag {
37            0 => {
38                if data.len() < 33 {
39                    return Err(StoreError::InvalidObject(
40                        "hash pack object id truncated".to_string(),
41                    ));
42                }
43                let hash = ContentHash::from_bytes(data[1..33].try_into().map_err(|_| {
44                    StoreError::InvalidObject("invalid hash id length".to_string())
45                })?);
46                Ok((Self::Hash(hash), 33))
47            }
48            1 => {
49                if data.len() < 17 {
50                    return Err(StoreError::InvalidObject(
51                        "change id pack object id truncated".to_string(),
52                    ));
53                }
54                let change_id = ChangeId::from_bytes(data[1..17].try_into().map_err(|_| {
55                    StoreError::InvalidObject("invalid change id length".to_string())
56                })?);
57                Ok((Self::ChangeId(change_id), 17))
58            }
59            _ => Err(StoreError::InvalidObject(format!(
60                "unknown pack object id tag {tag}"
61            ))),
62        }
63    }
64}
65
66#[derive(Debug, Clone)]
67pub struct PackObjectRecord {
68    pub id: PackObjectId,
69    pub obj_type: ObjectType,
70    pub data: Vec<u8>,
71    pub delta_base: Option<PackObjectId>,
72    pub path_hint: Option<String>,
73}
74
/// Magic bytes and format version that identify one kind of pack container.
#[derive(Clone, Copy, Debug)]
pub struct PackContainerSpec {
    /// Four magic bytes found at offset 0 of the container.
    pub magic: &'static [u8; 4],
    /// Format version, stored big-endian immediately after the magic.
    pub version: u32,
}
80
81#[derive(Debug, Clone)]
82pub struct PackEntryHeader {
83    pub id: PackObjectId,
84    pub obj_type: ObjectType,
85    pub uncompressed_size: usize,
86    pub compressed_size: usize,
87    pub delta_base: Option<PackObjectId>,
88    pub header_len: usize,
89}
90
91pub fn write_container_header(buf: &mut Vec<u8>, spec: PackContainerSpec, count: u64) {
92    buf.extend_from_slice(spec.magic);
93    buf.extend_from_slice(&spec.version.to_be_bytes());
94    buf.extend_from_slice(&count.to_be_bytes());
95}
96
97pub fn verify_container(data: &[u8], spec: PackContainerSpec) -> Result<(u64, usize, usize)> {
98    if data.len() < 16 + PACK_CHECKSUM_LEN {
99        return Err(StoreError::InvalidObject("Pack too short".to_string()));
100    }
101    if &data[..4] != spec.magic {
102        return Err(StoreError::InvalidObject("Invalid pack magic".to_string()));
103    }
104    let version = u32::from_be_bytes([data[4], data[5], data[6], data[7]]);
105    if version != spec.version {
106        return Err(StoreError::InvalidObject(format!(
107            "Unsupported pack version: {}",
108            version
109        )));
110    }
111
112    let content_end = data.len() - PACK_CHECKSUM_LEN;
113    let content = &data[..content_end];
114    let stored_checksum = &data[content_end..];
115    let computed_checksum = blake3::hash(content);
116    if computed_checksum.as_bytes() != stored_checksum {
117        return Err(StoreError::InvalidObject(
118            "Pack checksum mismatch".to_string(),
119        ));
120    }
121
122    let count = u64::from_be_bytes([
123        data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15],
124    ]);
125    Ok((count, 16, content_end))
126}
127
128pub fn append_container_checksum(buf: &mut Vec<u8>) {
129    let checksum = blake3::hash(buf);
130    buf.extend_from_slice(checksum.as_bytes());
131}
132
133pub fn encode_tagged_entry(
134    buf: &mut Vec<u8>,
135    record: &PackObjectRecord,
136    stored_type: ObjectType,
137    compressed: &[u8],
138) -> Result<()> {
139    encode_tagged_entry_parts(
140        buf,
141        record.id,
142        stored_type,
143        record.data.len(),
144        record.delta_base,
145        compressed,
146    )
147}
148
149pub fn encode_tagged_entry_parts(
150    buf: &mut Vec<u8>,
151    id: PackObjectId,
152    stored_type: ObjectType,
153    uncompressed_size: usize,
154    delta_base: Option<PackObjectId>,
155    compressed: &[u8],
156) -> Result<()> {
157    id.encode_tagged(buf);
158    varint::encode_type_and_size(stored_type, uncompressed_size as u64, buf);
159    varint::encode_varint(compressed.len() as u64, buf);
160    if stored_type == ObjectType::Delta {
161        let Some(base) = delta_base else {
162            return Err(StoreError::InvalidObject(
163                "Delta entry missing base id".to_string(),
164            ));
165        };
166        base.encode_tagged(buf);
167    }
168    buf.extend_from_slice(compressed);
169    Ok(())
170}
171
172pub fn decode_tagged_entry_header(data: &[u8]) -> Result<PackEntryHeader> {
173    let (id, id_len) = PackObjectId::decode_tagged(data)?;
174    let (obj_type, uncompressed_size, type_len) = varint::decode_type_and_size(&data[id_len..])
175        .ok_or_else(|| StoreError::InvalidObject("Truncated type+size varint".to_string()))?;
176    let varint_start = id_len + type_len;
177    let (compressed_size, comp_len) = varint::decode_varint(&data[varint_start..])
178        .ok_or_else(|| StoreError::InvalidObject("Truncated compressed_size varint".to_string()))?;
179    let mut header_len = varint_start + comp_len;
180
181    let delta_base = if obj_type == ObjectType::Delta {
182        let (base, base_len) = PackObjectId::decode_tagged(&data[header_len..])?;
183        header_len += base_len;
184        Some(base)
185    } else {
186        None
187    };
188
189    Ok(PackEntryHeader {
190        id,
191        obj_type,
192        uncompressed_size: uncompressed_size as usize,
193        compressed_size: compressed_size as usize,
194        delta_base,
195        header_len,
196    })
197}
198
199pub fn try_decode_tagged_entry_header(data: &[u8]) -> Result<Option<PackEntryHeader>> {
200    let Some(tag) = data.first().copied() else {
201        return Ok(None);
202    };
203
204    let (id, id_len) =
205        match tag {
206            0 => {
207                if data.len() < 33 {
208                    return Ok(None);
209                }
210                let hash = ContentHash::from_bytes(data[1..33].try_into().map_err(|_| {
211                    StoreError::InvalidObject("invalid hash id length".to_string())
212                })?);
213                (PackObjectId::Hash(hash), 33)
214            }
215            1 => {
216                if data.len() < 17 {
217                    return Ok(None);
218                }
219                let change_id = ChangeId::from_bytes(data[1..17].try_into().map_err(|_| {
220                    StoreError::InvalidObject("invalid change id length".to_string())
221                })?);
222                (PackObjectId::ChangeId(change_id), 17)
223            }
224            _ => {
225                return Err(StoreError::InvalidObject(format!(
226                    "unknown pack object id tag {tag}"
227                )));
228            }
229        };
230
231    let Some((obj_type, uncompressed_size, type_len)) =
232        varint::decode_type_and_size(&data[id_len..])
233    else {
234        return Ok(None);
235    };
236    let varint_start = id_len + type_len;
237    let Some((compressed_size, comp_len)) = varint::decode_varint(&data[varint_start..]) else {
238        return Ok(None);
239    };
240    let mut header_len = varint_start + comp_len;
241
242    let delta_base = if obj_type == ObjectType::Delta {
243        let Some(base_tag) = data.get(header_len).copied() else {
244            return Ok(None);
245        };
246        let (base, base_len) = match base_tag {
247            0 => {
248                let end = header_len + 33;
249                if data.len() < end {
250                    return Ok(None);
251                }
252                let hash = ContentHash::from_bytes(data[header_len + 1..end].try_into().map_err(
253                    |_| StoreError::InvalidObject("invalid hash id length".to_string()),
254                )?);
255                (PackObjectId::Hash(hash), 33)
256            }
257            1 => {
258                let end = header_len + 17;
259                if data.len() < end {
260                    return Ok(None);
261                }
262                let change_id =
263                    ChangeId::from_bytes(data[header_len + 1..end].try_into().map_err(|_| {
264                        StoreError::InvalidObject("invalid change id length".to_string())
265                    })?);
266                (PackObjectId::ChangeId(change_id), 17)
267            }
268            _ => {
269                return Err(StoreError::InvalidObject(format!(
270                    "unknown pack object id tag {base_tag}"
271                )));
272            }
273        };
274        header_len += base_len;
275        Some(base)
276    } else {
277        None
278    };
279
280    Ok(Some(PackEntryHeader {
281        id,
282        obj_type,
283        uncompressed_size: uncompressed_size as usize,
284        compressed_size: compressed_size as usize,
285        delta_base,
286        header_len,
287    }))
288}
289
290pub fn compress_pack_payload(data: &[u8], config: &CompressionConfig) -> Result<Vec<u8>> {
291    if !config.enabled || data.len() < config.min_size {
292        return Ok(data.to_vec());
293    }
294    #[cfg(feature = "zstd")]
295    {
296        match zstd::encode_all(data, config.level) {
297            Ok(compressed) if compressed.len() < data.len() => Ok(compressed),
298            _ => Ok(data.to_vec()),
299        }
300    }
301    #[cfg(not(feature = "zstd"))]
302    {
303        let _ = config;
304        Ok(data.to_vec())
305    }
306}
307
308pub fn decompress_pack_payload(data: &[u8], expected_size: usize) -> Result<Vec<u8>> {
309    #[cfg(feature = "zstd")]
310    {
311        use std::io::Read;
312        let mut decoder = zstd::stream::read::Decoder::new(data)
313            .map_err(|e| StoreError::InvalidObject(format!("zstd decode init failed: {e}")))?;
314        let capacity = if expected_size > 0 {
315            expected_size
316        } else {
317            data.len() * 2
318        };
319        let mut buf = Vec::with_capacity(capacity);
320        decoder
321            .read_to_end(&mut buf)
322            .map_err(|e| StoreError::InvalidObject(format!("zstd decompression failed: {e}")))?;
323        Ok(buf)
324    }
325    #[cfg(not(feature = "zstd"))]
326    {
327        let _ = expected_size;
328        Ok(data.to_vec())
329    }
330}
331
/// Reports whether `data` begins with the four-byte zstd frame magic
/// `28 B5 2F FD`.
///
/// Inputs shorter than four bytes never match.
pub fn has_zstd_magic(data: &[u8]) -> bool {
    const ZSTD_FRAME_MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];
    data.starts_with(&ZSTD_FRAME_MAGIC)
}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding then decoding an id yields the same id and consumes exactly
    /// the bytes that were written.
    #[test]
    fn tagged_pack_object_ids_round_trip() {
        let samples = vec![
            PackObjectId::Hash(ContentHash::compute(b"hash-object")),
            PackObjectId::ChangeId(ChangeId::generate()),
        ];

        samples.into_iter().for_each(|original| {
            let mut wire = Vec::new();
            original.encode_tagged(&mut wire);

            let (round_tripped, used) = PackObjectId::decode_tagged(&wire).unwrap();

            assert_eq!(round_tripped, original);
            assert_eq!(used, wire.len());
        });
    }

    /// A non-delta entry keyed by a change id decodes back to the same header
    /// fields that were encoded.
    #[test]
    fn tagged_entry_header_round_trips_mixed_identity() {
        let record = PackObjectRecord {
            id: PackObjectId::ChangeId(ChangeId::generate()),
            obj_type: ObjectType::State,
            data: vec![1, 2, 3, 4, 5],
            delta_base: None,
            path_hint: None,
        };

        let mut wire = Vec::new();
        encode_tagged_entry(&mut wire, &record, record.obj_type, &record.data).unwrap();
        let header = decode_tagged_entry_header(&wire).unwrap();

        assert_eq!(header.id, record.id);
        assert_eq!(header.obj_type, ObjectType::State);
        assert_eq!(header.uncompressed_size, 5);
        assert_eq!(header.compressed_size, 5);
        assert!(header.delta_base.is_none());
    }
}
376}