Skip to main content

objects/store/pack/
shared.rs

1// SPDX-License-Identifier: Apache-2.0
2#![deny(clippy::cast_possible_truncation)]
3
4use super::{
5    ObjectType, varint,
6    versioned_header::{HeaderChecksum, VersionedHeader},
7};
8use crate::{
9    object::{ChangeId, ContentHash},
10    store::{Result, StoreError, compression::CompressionConfig},
11};
12
/// Pack checksum length in bytes (not referenced elsewhere in this module;
/// presumably the width of the BLAKE3 trailer — confirm against callers).
pub const PACK_CHECKSUM_LEN: usize = 32;
/// Hard cap (1 GiB) on the decoded size of a single pack object, applied both
/// to declared sizes in entry headers and to actual decompressed output.
pub const MAX_PACK_OBJECT_OUTPUT_SIZE: usize = 1 << 30;
/// Upper bound (4 MiB) on the buffer pre-allocated before zstd decompression;
/// larger outputs grow the buffer on demand instead.
#[cfg(feature = "zstd")]
pub(super) const PACK_DECOMPRESSION_INITIAL_CAP: usize = 4 << 20;
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
19pub enum PackObjectId {
20    Hash(ContentHash),
21    ChangeId(ChangeId),
22}
23
24impl PackObjectId {
25    pub fn encode_tagged(self, buf: &mut Vec<u8>) {
26        match self {
27            Self::Hash(hash) => {
28                buf.push(0);
29                buf.extend_from_slice(hash.as_bytes());
30            }
31            Self::ChangeId(change_id) => {
32                buf.push(1);
33                buf.extend_from_slice(change_id.as_bytes());
34            }
35        }
36    }
37
38    pub fn decode_tagged(data: &[u8]) -> Result<(Self, usize)> {
39        let Some(tag) = data.first().copied() else {
40            return Err(StoreError::InvalidObject(
41                "missing pack object id tag".to_string(),
42            ));
43        };
44        match tag {
45            0 => {
46                if data.len() < 33 {
47                    return Err(StoreError::InvalidObject(
48                        "hash pack object id truncated".to_string(),
49                    ));
50                }
51                let hash = ContentHash::from_bytes(data[1..33].try_into().map_err(|_| {
52                    StoreError::InvalidObject("invalid hash id length".to_string())
53                })?);
54                Ok((Self::Hash(hash), 33))
55            }
56            1 => {
57                if data.len() < 17 {
58                    return Err(StoreError::InvalidObject(
59                        "change id pack object id truncated".to_string(),
60                    ));
61                }
62                let change_id = ChangeId::from_bytes(data[1..17].try_into().map_err(|_| {
63                    StoreError::InvalidObject("invalid change id length".to_string())
64                })?);
65                Ok((Self::ChangeId(change_id), 17))
66            }
67            _ => Err(StoreError::InvalidObject(format!(
68                "unknown pack object id tag {tag}"
69            ))),
70        }
71    }
72}
73
74#[derive(Debug, Clone)]
75pub struct PackObjectRecord {
76    pub id: PackObjectId,
77    pub obj_type: ObjectType,
78    pub data: Vec<u8>,
79    pub delta_base: Option<PackObjectId>,
80    pub path_hint: Option<String>,
81}
82
83#[derive(Debug, Clone, Copy)]
84pub struct PackContainerSpec {
85    pub magic: &'static [u8; 4],
86    pub version: u32,
87}
88
89#[derive(Debug, Clone)]
90pub struct PackEntryHeader {
91    pub id: PackObjectId,
92    pub obj_type: ObjectType,
93    pub uncompressed_size: usize,
94    pub compressed_size: usize,
95    pub delta_base: Option<PackObjectId>,
96    pub header_len: usize,
97}
98
99pub fn write_container_header(buf: &mut Vec<u8>, spec: PackContainerSpec, count: u64) {
100    pack_container_header(spec).write_vec(buf, count);
101}
102
103pub fn verify_container(data: &[u8], spec: PackContainerSpec) -> Result<(u64, usize, usize)> {
104    let header = pack_container_header(spec).verify(data)?;
105    Ok((header.count, header.header_len, header.content_end))
106}
107
108pub fn append_container_checksum(buf: &mut Vec<u8>) {
109    HeaderChecksum::Blake3Trailer.append(buf);
110}
111
112fn pack_container_header(spec: PackContainerSpec) -> VersionedHeader {
113    VersionedHeader {
114        magic: spec.magic,
115        version: spec.version,
116        checksum: HeaderChecksum::Blake3Trailer,
117        too_short: "Pack too short",
118        invalid_magic: "Invalid pack magic",
119        unsupported_version: "Unsupported pack version",
120        checksum_mismatch: "Pack checksum mismatch",
121    }
122}
123
124pub fn encode_tagged_entry(
125    buf: &mut Vec<u8>,
126    record: &PackObjectRecord,
127    stored_type: ObjectType,
128    compressed: &[u8],
129) -> Result<()> {
130    encode_tagged_entry_parts(
131        buf,
132        record.id,
133        stored_type,
134        record.data.len(),
135        record.delta_base,
136        compressed,
137    )
138}
139
140pub fn encode_tagged_entry_parts(
141    buf: &mut Vec<u8>,
142    id: PackObjectId,
143    stored_type: ObjectType,
144    uncompressed_size: usize,
145    delta_base: Option<PackObjectId>,
146    compressed: &[u8],
147) -> Result<()> {
148    id.encode_tagged(buf);
149    varint::encode_type_and_size(stored_type, uncompressed_size as u64, buf);
150    varint::encode_varint(compressed.len() as u64, buf);
151    if stored_type == ObjectType::Delta {
152        let Some(base) = delta_base else {
153            return Err(StoreError::InvalidObject(
154                "Delta entry missing base id".to_string(),
155            ));
156        };
157        base.encode_tagged(buf);
158    }
159    buf.extend_from_slice(compressed);
160    Ok(())
161}
162
163pub fn decode_tagged_entry_header(data: &[u8]) -> Result<PackEntryHeader> {
164    let (id, id_len) = PackObjectId::decode_tagged(data)?;
165    let (obj_type, uncompressed_size, type_len) = varint::decode_type_and_size(&data[id_len..])
166        .ok_or_else(|| StoreError::InvalidObject("Truncated type+size varint".to_string()))?;
167    let varint_start = id_len + type_len;
168    let (compressed_size, comp_len) = varint::decode_varint(&data[varint_start..])
169        .ok_or_else(|| StoreError::InvalidObject("Truncated compressed_size varint".to_string()))?;
170    let mut header_len = varint_start + comp_len;
171
172    let delta_base = if obj_type == ObjectType::Delta {
173        let (base, base_len) = PackObjectId::decode_tagged(&data[header_len..])?;
174        header_len += base_len;
175        Some(base)
176    } else {
177        None
178    };
179
180    Ok(PackEntryHeader {
181        id,
182        obj_type,
183        uncompressed_size: checked_decoded_size("uncompressed_size", uncompressed_size)?,
184        compressed_size: checked_decoded_size("compressed_size", compressed_size)?,
185        delta_base,
186        header_len,
187    })
188}
189
190pub fn try_decode_tagged_entry_header(data: &[u8]) -> Result<Option<PackEntryHeader>> {
191    let Some(tag) = data.first().copied() else {
192        return Ok(None);
193    };
194
195    let (id, id_len) =
196        match tag {
197            0 => {
198                if data.len() < 33 {
199                    return Ok(None);
200                }
201                let hash = ContentHash::from_bytes(data[1..33].try_into().map_err(|_| {
202                    StoreError::InvalidObject("invalid hash id length".to_string())
203                })?);
204                (PackObjectId::Hash(hash), 33)
205            }
206            1 => {
207                if data.len() < 17 {
208                    return Ok(None);
209                }
210                let change_id = ChangeId::from_bytes(data[1..17].try_into().map_err(|_| {
211                    StoreError::InvalidObject("invalid change id length".to_string())
212                })?);
213                (PackObjectId::ChangeId(change_id), 17)
214            }
215            _ => {
216                return Err(StoreError::InvalidObject(format!(
217                    "unknown pack object id tag {tag}"
218                )));
219            }
220        };
221
222    let Some((obj_type, uncompressed_size, type_len)) =
223        varint::decode_type_and_size(&data[id_len..])
224    else {
225        return Ok(None);
226    };
227    let varint_start = id_len + type_len;
228    let Some((compressed_size, comp_len)) = varint::decode_varint(&data[varint_start..]) else {
229        return Ok(None);
230    };
231    let mut header_len = varint_start + comp_len;
232
233    let delta_base = if obj_type == ObjectType::Delta {
234        let Some(base_tag) = data.get(header_len).copied() else {
235            return Ok(None);
236        };
237        let (base, base_len) = match base_tag {
238            0 => {
239                let end = header_len + 33;
240                if data.len() < end {
241                    return Ok(None);
242                }
243                let hash = ContentHash::from_bytes(data[header_len + 1..end].try_into().map_err(
244                    |_| StoreError::InvalidObject("invalid hash id length".to_string()),
245                )?);
246                (PackObjectId::Hash(hash), 33)
247            }
248            1 => {
249                let end = header_len + 17;
250                if data.len() < end {
251                    return Ok(None);
252                }
253                let change_id =
254                    ChangeId::from_bytes(data[header_len + 1..end].try_into().map_err(|_| {
255                        StoreError::InvalidObject("invalid change id length".to_string())
256                    })?);
257                (PackObjectId::ChangeId(change_id), 17)
258            }
259            _ => {
260                return Err(StoreError::InvalidObject(format!(
261                    "unknown pack object id tag {base_tag}"
262                )));
263            }
264        };
265        header_len += base_len;
266        Some(base)
267    } else {
268        None
269    };
270
271    Ok(Some(PackEntryHeader {
272        id,
273        obj_type,
274        uncompressed_size: checked_decoded_size("uncompressed_size", uncompressed_size)?,
275        compressed_size: checked_decoded_size("compressed_size", compressed_size)?,
276        delta_base,
277        header_len,
278    }))
279}
280
281fn checked_decoded_size(field: &str, size: u64) -> Result<usize> {
282    let size = usize::try_from(size).map_err(|_| {
283        StoreError::InvalidObject(format!("Decoded {field} exceeds platform limits"))
284    })?;
285    if field == "uncompressed_size" {
286        reject_pack_object_output_over_limit(size, MAX_PACK_OBJECT_OUTPUT_SIZE)?;
287    }
288    Ok(size)
289}
290
291pub fn compress_pack_payload(data: &[u8], config: &CompressionConfig) -> Result<Vec<u8>> {
292    if !config.enabled || data.len() < config.min_size {
293        return Ok(data.to_vec());
294    }
295    #[cfg(feature = "zstd")]
296    {
297        match zstd::encode_all(data, config.level) {
298            Ok(compressed) if compressed.len() < data.len() => Ok(compressed),
299            _ => Ok(data.to_vec()),
300        }
301    }
302    #[cfg(not(feature = "zstd"))]
303    {
304        let _ = config;
305        Ok(data.to_vec())
306    }
307}
308
309pub fn decompress_pack_payload(data: &[u8], expected_size: usize) -> Result<Vec<u8>> {
310    #[cfg(feature = "zstd")]
311    {
312        decompress_pack_payload_with_limit(data, expected_size, MAX_PACK_OBJECT_OUTPUT_SIZE)
313    }
314    #[cfg(not(feature = "zstd"))]
315    {
316        reject_pack_object_output_over_limit(expected_size, MAX_PACK_OBJECT_OUTPUT_SIZE)?;
317        reject_pack_object_output_over_limit(data.len(), MAX_PACK_OBJECT_OUTPUT_SIZE)?;
318        Ok(data.to_vec())
319    }
320}
321
/// Zstd-decompresses `data`, refusing to produce more than `max_output_size`
/// bytes.
///
/// `expected_size` is checked against the limit up front. It is also used as
/// the hint for the initial buffer capacity (see
/// `initial_decompression_capacity`); the actual output length is not compared
/// with it.
///
/// # Errors
///
/// Returns [`StoreError::InvalidObject`] in these cases:
/// - `expected_size` or the running output length exceeds `max_output_size`;
/// - the running output length overflows `usize`;
/// - the zstd decoder fails to initialise or to decode.
#[cfg(feature = "zstd")]
pub(super) fn decompress_pack_payload_with_limit(
    data: &[u8],
    expected_size: usize,
    max_output_size: usize,
) -> Result<Vec<u8>> {
    use std::io::Read;

    // Pack objects may be raw blobs, so this bound must be materially
    // larger than the delta-output limit. It is also intentionally
    // above the protocol default and loose-compression cap, while
    // still bounding one untrusted pack record to a finite allocation.
    reject_pack_object_output_over_limit(expected_size, max_output_size)?;

    let mut decoder = zstd::stream::read::Decoder::new(data)
        .map_err(|e| StoreError::InvalidObject(format!("zstd decode init failed: {e}")))?;
    let mut output = Vec::with_capacity(initial_decompression_capacity(
        data.len(),
        expected_size,
        max_output_size,
    ));
    let mut scratch = [0u8; 8192];

    loop {
        let filled = match decoder.read(&mut scratch) {
            Ok(0) => break,
            Ok(n) => n,
            Err(e) => {
                return Err(StoreError::InvalidObject(format!(
                    "zstd decompression failed: {e}"
                )));
            }
        };

        // Enforce the cap before copying so the buffer never exceeds it.
        let Some(total) = output.len().checked_add(filled) else {
            return Err(StoreError::InvalidObject(
                "Pack object output size overflows".to_string(),
            ));
        };
        reject_pack_object_output_over_limit(total, max_output_size)?;
        output.extend_from_slice(&scratch[..filled]);
    }

    Ok(output)
}
359
/// Picks the initial output buffer capacity for zstd decompression.
///
/// Starts from `expected_size`. A zero `expected_size` gives no hint, so twice
/// the compressed length is used instead. The result never exceeds
/// `PACK_DECOMPRESSION_INITIAL_CAP` or `max_output_size`.
#[cfg(feature = "zstd")]
fn initial_decompression_capacity(
    compressed_len: usize,
    expected_size: usize,
    max_output_size: usize,
) -> usize {
    let guess = match expected_size {
        0 => compressed_len.saturating_mul(2),
        declared => declared,
    };
    let ceiling = PACK_DECOMPRESSION_INITIAL_CAP.min(max_output_size);
    guess.min(ceiling)
}
374
375fn reject_pack_object_output_over_limit(size: usize, max: usize) -> Result<()> {
376    if size > max {
377        return Err(StoreError::InvalidObject(format!(
378            "Pack object output size {size} exceeds max {max}"
379        )));
380    }
381    Ok(())
382}
383
/// Reports whether `data` begins with the zstd frame magic number
/// (`0xFD2FB528`, which appears on disk little-endian as `28 B5 2F FD`).
pub fn has_zstd_magic(data: &[u8]) -> bool {
    const ZSTD_FRAME_MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];
    data.starts_with(&ZSTD_FRAME_MAGIC)
}
387
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tagged_pack_object_ids_round_trip() {
        let ids = [
            PackObjectId::Hash(ContentHash::compute(b"hash-object")),
            PackObjectId::ChangeId(ChangeId::generate()),
        ];

        for id in ids {
            let mut encoded = Vec::new();
            id.encode_tagged(&mut encoded);
            let (decoded, consumed) = PackObjectId::decode_tagged(&encoded).unwrap();
            assert_eq!(decoded, id);
            assert_eq!(consumed, encoded.len());
        }
    }

    #[test]
    fn decode_tagged_rejects_missing_unknown_and_truncated_ids() {
        assert!(PackObjectId::decode_tagged(&[]).is_err());
        assert!(PackObjectId::decode_tagged(&[2]).is_err());
        // Tag 0 followed by only 31 of the 32 hash bytes.
        assert!(PackObjectId::decode_tagged(&[0; 32]).is_err());
        // Tag 1 followed by only 15 of the 16 change-id bytes.
        assert!(PackObjectId::decode_tagged(&[1; 16]).is_err());
    }

    #[test]
    fn tagged_entry_header_round_trips_mixed_identity() {
        let record = PackObjectRecord {
            id: PackObjectId::ChangeId(ChangeId::generate()),
            obj_type: ObjectType::State,
            data: vec![1, 2, 3, 4, 5],
            delta_base: None,
            path_hint: None,
        };

        let mut encoded = Vec::new();
        encode_tagged_entry(&mut encoded, &record, record.obj_type, &record.data).unwrap();
        let decoded = decode_tagged_entry_header(&encoded).unwrap();

        assert_eq!(decoded.id, record.id);
        assert_eq!(decoded.obj_type, ObjectType::State);
        assert_eq!(decoded.uncompressed_size, 5);
        assert_eq!(decoded.compressed_size, 5);
        assert_eq!(decoded.delta_base, None);
    }

    #[test]
    fn tagged_entry_header_round_trips_delta_base() {
        let id = PackObjectId::ChangeId(ChangeId::generate());
        let base = PackObjectId::Hash(ContentHash::compute(b"delta-base"));
        let payload = [7u8, 8, 9];

        let mut encoded = Vec::new();
        encode_tagged_entry_parts(&mut encoded, id, ObjectType::Delta, 11, Some(base), &payload)
            .unwrap();
        let decoded = decode_tagged_entry_header(&encoded).unwrap();

        assert_eq!(decoded.id, id);
        assert_eq!(decoded.obj_type, ObjectType::Delta);
        assert_eq!(decoded.uncompressed_size, 11);
        assert_eq!(decoded.compressed_size, payload.len());
        assert_eq!(decoded.delta_base, Some(base));
        assert_eq!(&encoded[decoded.header_len..], &payload[..]);
    }

    #[test]
    fn try_decode_tagged_entry_header_needs_complete_header() {
        let base = PackObjectId::ChangeId(ChangeId::generate());
        let mut encoded = Vec::new();
        encode_tagged_entry_parts(
            &mut encoded,
            PackObjectId::Hash(ContentHash::compute(b"incremental")),
            ObjectType::Delta,
            4,
            Some(base),
            &[1, 2, 3, 4],
        )
        .unwrap();
        let header_len = decode_tagged_entry_header(&encoded).unwrap().header_len;

        for cut in 0..header_len {
            assert!(
                try_decode_tagged_entry_header(&encoded[..cut])
                    .unwrap()
                    .is_none(),
                "prefix of {cut} bytes must be reported as incomplete",
            );
        }
        let header = try_decode_tagged_entry_header(&encoded[..header_len])
            .unwrap()
            .expect("a complete header must decode");
        assert_eq!(header.header_len, header_len);
        assert_eq!(header.delta_base, Some(base));
    }

    #[test]
    fn tagged_entry_header_rejects_size_that_truncates_on_32_bit() {
        let mut encoded = Vec::new();
        PackObjectId::Hash(ContentHash::compute(b"oversized-pack-object"))
            .encode_tagged(&mut encoded);
        varint::encode_type_and_size(ObjectType::Blob, u64::from(u32::MAX) + 1, &mut encoded);
        varint::encode_varint(1, &mut encoded);
        encoded.push(0);

        let result = decode_tagged_entry_header(&encoded);

        let error = result.expect_err("absurd 32-bit-overflow size must be rejected");
        assert!(
            matches!(&error, StoreError::InvalidObject(message) if message.contains("platform limits") || message.contains("Pack object output size")),
            "expected size-limit InvalidObject, got: {error:?}",
        );
    }

    #[test]
    fn tagged_entry_header_rejects_u64_max_size_when_platform_cannot_represent_it() {
        let mut encoded = Vec::new();
        PackObjectId::Hash(ContentHash::compute(b"u64-max-pack-object"))
            .encode_tagged(&mut encoded);
        varint::encode_type_and_size(ObjectType::Blob, u64::MAX, &mut encoded);
        varint::encode_varint(1, &mut encoded);
        encoded.push(0);

        let result = decode_tagged_entry_header(&encoded);

        let error = result.expect_err("absurd u64::MAX size must be rejected");
        assert!(
            matches!(&error, StoreError::InvalidObject(message) if message.contains("platform limits") || message.contains("Pack object output size")),
            "expected size-limit InvalidObject, got: {error:?}",
        );
    }
}