Skip to main content

objects/store/pack/
shared.rs

1// SPDX-License-Identifier: Apache-2.0
2#![deny(clippy::cast_possible_truncation)]
3
4use heddle_format::compression::CompressionConfig;
5
6use super::{
7    ObjectType, varint,
8    versioned_header::{HeaderChecksum, VersionedHeader},
9};
10use crate::{
11    object::{ChangeId, ContentHash},
12    store::{Result, StoreError},
13};
14
/// Size in bytes of a pack checksum (32 bytes, the width of a BLAKE3 digest).
// NOTE(review): presumably the length of the trailer appended by
// `append_container_checksum` (`HeaderChecksum::Blake3Trailer`) — confirm.
pub const PACK_CHECKSUM_LEN: usize = 32;

/// Hard cap (1 GiB) on the decoded size of a single pack object.
///
/// Enforced on the `uncompressed_size` declared in an entry header and on the
/// bytes actually produced while decompressing a payload.
pub const MAX_PACK_OBJECT_OUTPUT_SIZE: usize = 1 << 30;

/// Largest buffer (4 MiB) reserved up front when decompressing a pack payload;
/// bigger outputs grow the buffer on demand.
#[cfg(feature = "zstd")]
pub(super) const PACK_DECOMPRESSION_INITIAL_CAP: usize = 4 << 20;
19
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
21pub enum PackObjectId {
22    Hash(ContentHash),
23    ChangeId(ChangeId),
24}
25
26impl PackObjectId {
27    pub fn encode_tagged(self, buf: &mut Vec<u8>) {
28        match self {
29            Self::Hash(hash) => {
30                buf.push(0);
31                buf.extend_from_slice(hash.as_bytes());
32            }
33            Self::ChangeId(change_id) => {
34                buf.push(1);
35                buf.extend_from_slice(change_id.as_bytes());
36            }
37        }
38    }
39
40    pub fn decode_tagged(data: &[u8]) -> Result<(Self, usize)> {
41        let Some(tag) = data.first().copied() else {
42            return Err(StoreError::InvalidObject(
43                "missing pack object id tag".to_string(),
44            ));
45        };
46        match tag {
47            0 => {
48                if data.len() < 33 {
49                    return Err(StoreError::InvalidObject(
50                        "hash pack object id truncated".to_string(),
51                    ));
52                }
53                let hash = ContentHash::from_bytes(data[1..33].try_into().map_err(|_| {
54                    StoreError::InvalidObject("invalid hash id length".to_string())
55                })?);
56                Ok((Self::Hash(hash), 33))
57            }
58            1 => {
59                if data.len() < 17 {
60                    return Err(StoreError::InvalidObject(
61                        "change id pack object id truncated".to_string(),
62                    ));
63                }
64                let change_id = ChangeId::from_bytes(data[1..17].try_into().map_err(|_| {
65                    StoreError::InvalidObject("invalid change id length".to_string())
66                })?);
67                Ok((Self::ChangeId(change_id), 17))
68            }
69            _ => Err(StoreError::InvalidObject(format!(
70                "unknown pack object id tag {tag}"
71            ))),
72        }
73    }
74}
75
76#[derive(Debug, Clone)]
77pub struct PackObjectRecord {
78    pub id: PackObjectId,
79    pub obj_type: ObjectType,
80    pub data: Vec<u8>,
81    pub delta_base: Option<PackObjectId>,
82    pub path_hint: Option<String>,
83}
84
/// Identifies a pack container format: the 4-byte magic and the version that
/// `write_container_header` emits and `verify_container` expects.
#[derive(Clone, Copy, Debug)]
pub struct PackContainerSpec {
    /// Four magic bytes at the start of the container.
    pub magic: &'static [u8; 4],
    /// Container format version.
    pub version: u32,
}
90
91#[derive(Debug, Clone)]
92pub struct PackEntryHeader {
93    pub id: PackObjectId,
94    pub obj_type: ObjectType,
95    pub uncompressed_size: usize,
96    pub compressed_size: usize,
97    pub delta_base: Option<PackObjectId>,
98    pub header_len: usize,
99}
100
101pub fn write_container_header(buf: &mut Vec<u8>, spec: PackContainerSpec, count: u64) {
102    pack_container_header(spec).write_vec(buf, count);
103}
104
105pub fn verify_container(data: &[u8], spec: PackContainerSpec) -> Result<(u64, usize, usize)> {
106    let header = pack_container_header(spec).verify(data)?;
107    Ok((header.count, header.header_len, header.content_end))
108}
109
110pub fn append_container_checksum(buf: &mut Vec<u8>) {
111    HeaderChecksum::Blake3Trailer.append(buf);
112}
113
114fn pack_container_header(spec: PackContainerSpec) -> VersionedHeader {
115    VersionedHeader {
116        magic: spec.magic,
117        version: spec.version,
118        checksum: HeaderChecksum::Blake3Trailer,
119        too_short: "Pack too short",
120        invalid_magic: "Invalid pack magic",
121        unsupported_version: "Unsupported pack version",
122        checksum_mismatch: "Pack checksum mismatch",
123    }
124}
125
126pub fn encode_tagged_entry(
127    buf: &mut Vec<u8>,
128    record: &PackObjectRecord,
129    stored_type: ObjectType,
130    compressed: &[u8],
131) -> Result<()> {
132    encode_tagged_entry_parts(
133        buf,
134        record.id,
135        stored_type,
136        record.data.len(),
137        record.delta_base,
138        compressed,
139    )
140}
141
142pub fn encode_tagged_entry_parts(
143    buf: &mut Vec<u8>,
144    id: PackObjectId,
145    stored_type: ObjectType,
146    uncompressed_size: usize,
147    delta_base: Option<PackObjectId>,
148    compressed: &[u8],
149) -> Result<()> {
150    id.encode_tagged(buf);
151    varint::encode_type_and_size(stored_type, uncompressed_size as u64, buf);
152    varint::encode_varint(compressed.len() as u64, buf);
153    if stored_type == ObjectType::Delta {
154        let Some(base) = delta_base else {
155            return Err(StoreError::InvalidObject(
156                "Delta entry missing base id".to_string(),
157            ));
158        };
159        base.encode_tagged(buf);
160    }
161    buf.extend_from_slice(compressed);
162    Ok(())
163}
164
165pub fn decode_tagged_entry_header(data: &[u8]) -> Result<PackEntryHeader> {
166    let (id, id_len) = PackObjectId::decode_tagged(data)?;
167    let (obj_type, uncompressed_size, type_len) = varint::decode_type_and_size(&data[id_len..])
168        .ok_or_else(|| StoreError::InvalidObject("Truncated type+size varint".to_string()))?;
169    let varint_start = id_len + type_len;
170    let (compressed_size, comp_len) = varint::decode_varint(&data[varint_start..])
171        .ok_or_else(|| StoreError::InvalidObject("Truncated compressed_size varint".to_string()))?;
172    let mut header_len = varint_start + comp_len;
173
174    let delta_base = if obj_type == ObjectType::Delta {
175        let (base, base_len) = PackObjectId::decode_tagged(&data[header_len..])?;
176        header_len += base_len;
177        Some(base)
178    } else {
179        None
180    };
181
182    Ok(PackEntryHeader {
183        id,
184        obj_type,
185        uncompressed_size: checked_decoded_size("uncompressed_size", uncompressed_size)?,
186        compressed_size: checked_decoded_size("compressed_size", compressed_size)?,
187        delta_base,
188        header_len,
189    })
190}
191
192pub fn try_decode_tagged_entry_header(data: &[u8]) -> Result<Option<PackEntryHeader>> {
193    let Some(tag) = data.first().copied() else {
194        return Ok(None);
195    };
196
197    let (id, id_len) =
198        match tag {
199            0 => {
200                if data.len() < 33 {
201                    return Ok(None);
202                }
203                let hash = ContentHash::from_bytes(data[1..33].try_into().map_err(|_| {
204                    StoreError::InvalidObject("invalid hash id length".to_string())
205                })?);
206                (PackObjectId::Hash(hash), 33)
207            }
208            1 => {
209                if data.len() < 17 {
210                    return Ok(None);
211                }
212                let change_id = ChangeId::from_bytes(data[1..17].try_into().map_err(|_| {
213                    StoreError::InvalidObject("invalid change id length".to_string())
214                })?);
215                (PackObjectId::ChangeId(change_id), 17)
216            }
217            _ => {
218                return Err(StoreError::InvalidObject(format!(
219                    "unknown pack object id tag {tag}"
220                )));
221            }
222        };
223
224    let Some((obj_type, uncompressed_size, type_len)) =
225        varint::decode_type_and_size(&data[id_len..])
226    else {
227        return Ok(None);
228    };
229    let varint_start = id_len + type_len;
230    let Some((compressed_size, comp_len)) = varint::decode_varint(&data[varint_start..]) else {
231        return Ok(None);
232    };
233    let mut header_len = varint_start + comp_len;
234
235    let delta_base = if obj_type == ObjectType::Delta {
236        let Some(base_tag) = data.get(header_len).copied() else {
237            return Ok(None);
238        };
239        let (base, base_len) = match base_tag {
240            0 => {
241                let end = header_len + 33;
242                if data.len() < end {
243                    return Ok(None);
244                }
245                let hash = ContentHash::from_bytes(data[header_len + 1..end].try_into().map_err(
246                    |_| StoreError::InvalidObject("invalid hash id length".to_string()),
247                )?);
248                (PackObjectId::Hash(hash), 33)
249            }
250            1 => {
251                let end = header_len + 17;
252                if data.len() < end {
253                    return Ok(None);
254                }
255                let change_id =
256                    ChangeId::from_bytes(data[header_len + 1..end].try_into().map_err(|_| {
257                        StoreError::InvalidObject("invalid change id length".to_string())
258                    })?);
259                (PackObjectId::ChangeId(change_id), 17)
260            }
261            _ => {
262                return Err(StoreError::InvalidObject(format!(
263                    "unknown pack object id tag {base_tag}"
264                )));
265            }
266        };
267        header_len += base_len;
268        Some(base)
269    } else {
270        None
271    };
272
273    Ok(Some(PackEntryHeader {
274        id,
275        obj_type,
276        uncompressed_size: checked_decoded_size("uncompressed_size", uncompressed_size)?,
277        compressed_size: checked_decoded_size("compressed_size", compressed_size)?,
278        delta_base,
279        header_len,
280    }))
281}
282
283fn checked_decoded_size(field: &str, size: u64) -> Result<usize> {
284    let size = usize::try_from(size).map_err(|_| {
285        StoreError::InvalidObject(format!("Decoded {field} exceeds platform limits"))
286    })?;
287    if field == "uncompressed_size" {
288        reject_pack_object_output_over_limit(size, MAX_PACK_OBJECT_OUTPUT_SIZE)?;
289    }
290    Ok(size)
291}
292
293pub fn compress_pack_payload(data: &[u8], config: &CompressionConfig) -> Result<Vec<u8>> {
294    if !config.enabled || data.len() < config.min_size {
295        return Ok(data.to_vec());
296    }
297    #[cfg(feature = "zstd")]
298    {
299        match zstd::encode_all(data, config.level) {
300            Ok(compressed) if compressed.len() < data.len() => Ok(compressed),
301            _ => Ok(data.to_vec()),
302        }
303    }
304    #[cfg(not(feature = "zstd"))]
305    {
306        let _ = config;
307        Ok(data.to_vec())
308    }
309}
310
311pub fn decompress_pack_payload(data: &[u8], expected_size: usize) -> Result<Vec<u8>> {
312    #[cfg(feature = "zstd")]
313    {
314        decompress_pack_payload_with_limit(data, expected_size, MAX_PACK_OBJECT_OUTPUT_SIZE)
315    }
316    #[cfg(not(feature = "zstd"))]
317    {
318        reject_pack_object_output_over_limit(expected_size, MAX_PACK_OBJECT_OUTPUT_SIZE)?;
319        reject_pack_object_output_over_limit(data.len(), MAX_PACK_OBJECT_OUTPUT_SIZE)?;
320        Ok(data.to_vec())
321    }
322}
323
/// Zstd-decodes `data`, refusing to produce more than `max_output_size` bytes.
///
/// `expected_size` must not exceed `max_output_size`. It is otherwise used
/// only to size the initial buffer; the returned length is not compared
/// against it.
///
/// # Errors
/// `StoreError::InvalidObject` when `expected_size` or the actual output
/// exceeds `max_output_size`, or when the zstd stream cannot be decoded.
#[cfg(feature = "zstd")]
pub(super) fn decompress_pack_payload_with_limit(
    data: &[u8],
    expected_size: usize,
    max_output_size: usize,
) -> Result<Vec<u8>> {
    use std::io::Read;

    // Pack objects may be raw blobs, so this bound must be materially
    // larger than the delta-output limit. It is also intentionally
    // above the protocol default and loose-compression cap, while
    // still bounding one untrusted pack record to a finite allocation.
    reject_pack_object_output_over_limit(expected_size, max_output_size)?;

    let decoder = zstd::stream::read::Decoder::new(data)
        .map_err(|e| StoreError::InvalidObject(format!("zstd decode init failed: {e}")))?;
    let capacity = initial_decompression_capacity(data.len(), expected_size, max_output_size);
    let mut buf = Vec::with_capacity(capacity);

    // Read at most one byte past the limit. That is enough to detect an
    // oversized stream without decoding (or buffering) any further.
    // `read_to_end` also retries `ErrorKind::Interrupted` instead of failing.
    let read_limit = u64::try_from(max_output_size)
        .unwrap_or(u64::MAX)
        .saturating_add(1);
    decoder
        .take(read_limit)
        .read_to_end(&mut buf)
        .map_err(|e| StoreError::InvalidObject(format!("zstd decompression failed: {e}")))?;
    reject_pack_object_output_over_limit(buf.len(), max_output_size)?;

    Ok(buf)
}
361
/// Chooses how many bytes to reserve before decompressing a payload.
///
/// The hint is `expected_size` when non-zero, otherwise twice the compressed
/// length (saturating). The result never exceeds `PACK_DECOMPRESSION_INITIAL_CAP`
/// or `max_output_size`.
#[cfg(feature = "zstd")]
fn initial_decompression_capacity(
    compressed_len: usize,
    expected_size: usize,
    max_output_size: usize,
) -> usize {
    // An expected size of zero is treated as "no hint".
    let size_hint = match expected_size {
        0 => compressed_len.saturating_mul(2),
        declared => declared,
    };
    let ceiling = PACK_DECOMPRESSION_INITIAL_CAP.min(max_output_size);
    size_hint.min(ceiling)
}
376
377fn reject_pack_object_output_over_limit(size: usize, max: usize) -> Result<()> {
378    if size > max {
379        return Err(StoreError::InvalidObject(format!(
380            "Pack object output size {size} exceeds max {max}"
381        )));
382    }
383    Ok(())
384}
385
/// Reports whether `data` begins with the zstd frame magic number
/// (`0xFD2FB528`, stored little-endian as `28 B5 2F FD`).
pub fn has_zstd_magic(data: &[u8]) -> bool {
    const ZSTD_FRAME_MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];
    data.starts_with(&ZSTD_FRAME_MAGIC)
}
389
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tagged_pack_object_ids_round_trip() {
        let ids = [
            PackObjectId::Hash(ContentHash::compute(b"hash-object")),
            PackObjectId::ChangeId(ChangeId::generate()),
        ];

        for id in ids {
            let mut encoded = Vec::new();
            id.encode_tagged(&mut encoded);
            let (decoded, consumed) = PackObjectId::decode_tagged(&encoded).unwrap();
            assert_eq!(decoded, id);
            assert_eq!(consumed, encoded.len());
        }
    }

    #[test]
    fn tagged_entry_header_round_trips_mixed_identity() {
        let record = PackObjectRecord {
            id: PackObjectId::ChangeId(ChangeId::generate()),
            obj_type: ObjectType::State,
            data: vec![1, 2, 3, 4, 5],
            delta_base: None,
            path_hint: None,
        };

        let mut encoded = Vec::new();
        encode_tagged_entry(&mut encoded, &record, record.obj_type, &record.data).unwrap();
        let decoded = decode_tagged_entry_header(&encoded).unwrap();

        assert_eq!(decoded.id, record.id);
        assert_eq!(decoded.obj_type, ObjectType::State);
        assert_eq!(decoded.uncompressed_size, 5);
        assert_eq!(decoded.compressed_size, 5);
        assert_eq!(decoded.delta_base, None);
    }

    /// Encodes a delta entry (change-id identity, hash base) and returns the
    /// base id together with the encoded bytes.
    fn encoded_delta_entry(payload: &[u8]) -> (PackObjectId, Vec<u8>) {
        let base = PackObjectId::Hash(ContentHash::compute(b"delta-base"));
        let record = PackObjectRecord {
            id: PackObjectId::ChangeId(ChangeId::generate()),
            obj_type: ObjectType::Blob,
            data: payload.to_vec(),
            delta_base: Some(base),
            path_hint: None,
        };
        let mut encoded = Vec::new();
        encode_tagged_entry(&mut encoded, &record, ObjectType::Delta, &record.data).unwrap();
        (base, encoded)
    }

    #[test]
    fn tagged_entry_header_round_trips_delta_base_through_both_decoders() {
        let payload = [9u8, 8, 7];
        let (base, encoded) = encoded_delta_entry(&payload);

        let strict = decode_tagged_entry_header(&encoded).unwrap();
        assert_eq!(strict.obj_type, ObjectType::Delta);
        assert_eq!(strict.delta_base, Some(base));
        assert_eq!(strict.uncompressed_size, payload.len());
        assert_eq!(strict.compressed_size, payload.len());
        assert_eq!(strict.header_len + payload.len(), encoded.len());

        let lenient = try_decode_tagged_entry_header(&encoded)
            .unwrap()
            .expect("complete header must decode");
        assert_eq!(lenient.id, strict.id);
        assert_eq!(lenient.obj_type, strict.obj_type);
        assert_eq!(lenient.uncompressed_size, strict.uncompressed_size);
        assert_eq!(lenient.compressed_size, strict.compressed_size);
        assert_eq!(lenient.delta_base, strict.delta_base);
        assert_eq!(lenient.header_len, strict.header_len);
    }

    #[test]
    fn try_decode_tagged_entry_header_waits_for_complete_header() {
        let (_, encoded) = encoded_delta_entry(&[1, 2, 3]);
        let header_len = decode_tagged_entry_header(&encoded).unwrap().header_len;

        for len in 0..header_len {
            let partial = try_decode_tagged_entry_header(&encoded[..len]).unwrap();
            assert!(partial.is_none(), "prefix of {len} bytes must be incomplete");
        }
        assert!(try_decode_tagged_entry_header(&encoded[..header_len])
            .unwrap()
            .is_some());
    }

    #[test]
    fn tagged_id_decoders_reject_unknown_tag() {
        let data = [7u8; 40];

        let strict = PackObjectId::decode_tagged(&data).expect_err("tag 7 is not a valid id tag");
        assert!(
            matches!(&strict, StoreError::InvalidObject(message) if message.contains("unknown pack object id tag 7")),
            "expected unknown-tag InvalidObject, got: {strict:?}",
        );

        let lenient =
            try_decode_tagged_entry_header(&data).expect_err("tag 7 is not a valid id tag");
        assert!(
            matches!(&lenient, StoreError::InvalidObject(message) if message.contains("unknown pack object id tag 7")),
            "expected unknown-tag InvalidObject, got: {lenient:?}",
        );
    }

    #[test]
    fn encode_tagged_entry_rejects_delta_without_base() {
        let mut encoded = Vec::new();
        let error = encode_tagged_entry_parts(
            &mut encoded,
            PackObjectId::Hash(ContentHash::compute(b"orphan-delta")),
            ObjectType::Delta,
            3,
            None,
            &[1, 2, 3],
        )
        .expect_err("delta entries must name a base");
        assert!(
            matches!(&error, StoreError::InvalidObject(message) if message.contains("missing base id")),
            "expected missing-base InvalidObject, got: {error:?}",
        );
    }

    #[test]
    fn tagged_entry_header_rejects_size_that_truncates_on_32_bit() {
        let mut encoded = Vec::new();
        PackObjectId::Hash(ContentHash::compute(b"oversized-pack-object"))
            .encode_tagged(&mut encoded);
        varint::encode_type_and_size(ObjectType::Blob, u64::from(u32::MAX) + 1, &mut encoded);
        varint::encode_varint(1, &mut encoded);
        encoded.push(0);

        let result = decode_tagged_entry_header(&encoded);

        let error = result.expect_err("absurd 32-bit-overflow size must be rejected");
        assert!(
            matches!(&error, StoreError::InvalidObject(message) if message.contains("platform limits") || message.contains("Pack object output size")),
            "expected size-limit InvalidObject, got: {error:?}",
        );
    }

    #[test]
    fn tagged_entry_header_rejects_u64_max_size_when_platform_cannot_represent_it() {
        let mut encoded = Vec::new();
        PackObjectId::Hash(ContentHash::compute(b"u64-max-pack-object"))
            .encode_tagged(&mut encoded);
        varint::encode_type_and_size(ObjectType::Blob, u64::MAX, &mut encoded);
        varint::encode_varint(1, &mut encoded);
        encoded.push(0);

        let result = decode_tagged_entry_header(&encoded);

        let error = result.expect_err("absurd u64::MAX size must be rejected");
        assert!(
            matches!(&error, StoreError::InvalidObject(message) if message.contains("platform limits") || message.contains("Pack object output size")),
            "expected size-limit InvalidObject, got: {error:?}",
        );
    }
}