1#![deny(clippy::cast_possible_truncation)]
3
4use super::{
5 ObjectType, varint,
6 versioned_header::{HeaderChecksum, VersionedHeader},
7};
8use crate::{
9 object::{ChangeId, ContentHash},
10 store::{Result, StoreError, compression::CompressionConfig},
11};
12
/// Length in bytes of a pack checksum.
///
/// NOTE(review): presumably the size of the BLAKE3 trailer appended by
/// `append_container_checksum` — confirm against `versioned_header`.
pub const PACK_CHECKSUM_LEN: usize = 32;

/// Largest object output size (1 GiB) accepted when decoding entry headers and
/// when decompressing pack payloads.
pub const MAX_PACK_OBJECT_OUTPUT_SIZE: usize = 1 << 30;

/// Upper bound (4 MiB) on the buffer capacity reserved up front before
/// streaming a decompression.
#[cfg(feature = "zstd")]
pub(super) const PACK_DECOMPRESSION_INITIAL_CAP: usize = 4 << 20;
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
19pub enum PackObjectId {
20 Hash(ContentHash),
21 ChangeId(ChangeId),
22}
23
24impl PackObjectId {
25 pub fn encode_tagged(self, buf: &mut Vec<u8>) {
26 match self {
27 Self::Hash(hash) => {
28 buf.push(0);
29 buf.extend_from_slice(hash.as_bytes());
30 }
31 Self::ChangeId(change_id) => {
32 buf.push(1);
33 buf.extend_from_slice(change_id.as_bytes());
34 }
35 }
36 }
37
38 pub fn decode_tagged(data: &[u8]) -> Result<(Self, usize)> {
39 let Some(tag) = data.first().copied() else {
40 return Err(StoreError::InvalidObject(
41 "missing pack object id tag".to_string(),
42 ));
43 };
44 match tag {
45 0 => {
46 if data.len() < 33 {
47 return Err(StoreError::InvalidObject(
48 "hash pack object id truncated".to_string(),
49 ));
50 }
51 let hash = ContentHash::from_bytes(data[1..33].try_into().map_err(|_| {
52 StoreError::InvalidObject("invalid hash id length".to_string())
53 })?);
54 Ok((Self::Hash(hash), 33))
55 }
56 1 => {
57 if data.len() < 17 {
58 return Err(StoreError::InvalidObject(
59 "change id pack object id truncated".to_string(),
60 ));
61 }
62 let change_id = ChangeId::from_bytes(data[1..17].try_into().map_err(|_| {
63 StoreError::InvalidObject("invalid change id length".to_string())
64 })?);
65 Ok((Self::ChangeId(change_id), 17))
66 }
67 _ => Err(StoreError::InvalidObject(format!(
68 "unknown pack object id tag {tag}"
69 ))),
70 }
71 }
72}
73
74#[derive(Debug, Clone)]
75pub struct PackObjectRecord {
76 pub id: PackObjectId,
77 pub obj_type: ObjectType,
78 pub data: Vec<u8>,
79 pub delta_base: Option<PackObjectId>,
80 pub path_hint: Option<String>,
81}
82
/// Magic bytes and format version that identify one kind of pack container.
#[derive(Debug, Clone, Copy)]
pub struct PackContainerSpec {
    /// Four-byte magic passed through to the container header.
    pub magic: &'static [u8; 4],
    /// Container format version passed through to the container header.
    pub version: u32,
}
88
89#[derive(Debug, Clone)]
90pub struct PackEntryHeader {
91 pub id: PackObjectId,
92 pub obj_type: ObjectType,
93 pub uncompressed_size: usize,
94 pub compressed_size: usize,
95 pub delta_base: Option<PackObjectId>,
96 pub header_len: usize,
97}
98
99pub fn write_container_header(buf: &mut Vec<u8>, spec: PackContainerSpec, count: u64) {
100 pack_container_header(spec).write_vec(buf, count);
101}
102
103pub fn verify_container(data: &[u8], spec: PackContainerSpec) -> Result<(u64, usize, usize)> {
104 let header = pack_container_header(spec).verify(data)?;
105 Ok((header.count, header.header_len, header.content_end))
106}
107
108pub fn append_container_checksum(buf: &mut Vec<u8>) {
109 HeaderChecksum::Blake3Trailer.append(buf);
110}
111
112fn pack_container_header(spec: PackContainerSpec) -> VersionedHeader {
113 VersionedHeader {
114 magic: spec.magic,
115 version: spec.version,
116 checksum: HeaderChecksum::Blake3Trailer,
117 too_short: "Pack too short",
118 invalid_magic: "Invalid pack magic",
119 unsupported_version: "Unsupported pack version",
120 checksum_mismatch: "Pack checksum mismatch",
121 }
122}
123
124pub fn encode_tagged_entry(
125 buf: &mut Vec<u8>,
126 record: &PackObjectRecord,
127 stored_type: ObjectType,
128 compressed: &[u8],
129) -> Result<()> {
130 encode_tagged_entry_parts(
131 buf,
132 record.id,
133 stored_type,
134 record.data.len(),
135 record.delta_base,
136 compressed,
137 )
138}
139
140pub fn encode_tagged_entry_parts(
141 buf: &mut Vec<u8>,
142 id: PackObjectId,
143 stored_type: ObjectType,
144 uncompressed_size: usize,
145 delta_base: Option<PackObjectId>,
146 compressed: &[u8],
147) -> Result<()> {
148 id.encode_tagged(buf);
149 varint::encode_type_and_size(stored_type, uncompressed_size as u64, buf);
150 varint::encode_varint(compressed.len() as u64, buf);
151 if stored_type == ObjectType::Delta {
152 let Some(base) = delta_base else {
153 return Err(StoreError::InvalidObject(
154 "Delta entry missing base id".to_string(),
155 ));
156 };
157 base.encode_tagged(buf);
158 }
159 buf.extend_from_slice(compressed);
160 Ok(())
161}
162
163pub fn decode_tagged_entry_header(data: &[u8]) -> Result<PackEntryHeader> {
164 let (id, id_len) = PackObjectId::decode_tagged(data)?;
165 let (obj_type, uncompressed_size, type_len) = varint::decode_type_and_size(&data[id_len..])
166 .ok_or_else(|| StoreError::InvalidObject("Truncated type+size varint".to_string()))?;
167 let varint_start = id_len + type_len;
168 let (compressed_size, comp_len) = varint::decode_varint(&data[varint_start..])
169 .ok_or_else(|| StoreError::InvalidObject("Truncated compressed_size varint".to_string()))?;
170 let mut header_len = varint_start + comp_len;
171
172 let delta_base = if obj_type == ObjectType::Delta {
173 let (base, base_len) = PackObjectId::decode_tagged(&data[header_len..])?;
174 header_len += base_len;
175 Some(base)
176 } else {
177 None
178 };
179
180 Ok(PackEntryHeader {
181 id,
182 obj_type,
183 uncompressed_size: checked_decoded_size("uncompressed_size", uncompressed_size)?,
184 compressed_size: checked_decoded_size("compressed_size", compressed_size)?,
185 delta_base,
186 header_len,
187 })
188}
189
190pub fn try_decode_tagged_entry_header(data: &[u8]) -> Result<Option<PackEntryHeader>> {
191 let Some(tag) = data.first().copied() else {
192 return Ok(None);
193 };
194
195 let (id, id_len) =
196 match tag {
197 0 => {
198 if data.len() < 33 {
199 return Ok(None);
200 }
201 let hash = ContentHash::from_bytes(data[1..33].try_into().map_err(|_| {
202 StoreError::InvalidObject("invalid hash id length".to_string())
203 })?);
204 (PackObjectId::Hash(hash), 33)
205 }
206 1 => {
207 if data.len() < 17 {
208 return Ok(None);
209 }
210 let change_id = ChangeId::from_bytes(data[1..17].try_into().map_err(|_| {
211 StoreError::InvalidObject("invalid change id length".to_string())
212 })?);
213 (PackObjectId::ChangeId(change_id), 17)
214 }
215 _ => {
216 return Err(StoreError::InvalidObject(format!(
217 "unknown pack object id tag {tag}"
218 )));
219 }
220 };
221
222 let Some((obj_type, uncompressed_size, type_len)) =
223 varint::decode_type_and_size(&data[id_len..])
224 else {
225 return Ok(None);
226 };
227 let varint_start = id_len + type_len;
228 let Some((compressed_size, comp_len)) = varint::decode_varint(&data[varint_start..]) else {
229 return Ok(None);
230 };
231 let mut header_len = varint_start + comp_len;
232
233 let delta_base = if obj_type == ObjectType::Delta {
234 let Some(base_tag) = data.get(header_len).copied() else {
235 return Ok(None);
236 };
237 let (base, base_len) = match base_tag {
238 0 => {
239 let end = header_len + 33;
240 if data.len() < end {
241 return Ok(None);
242 }
243 let hash = ContentHash::from_bytes(data[header_len + 1..end].try_into().map_err(
244 |_| StoreError::InvalidObject("invalid hash id length".to_string()),
245 )?);
246 (PackObjectId::Hash(hash), 33)
247 }
248 1 => {
249 let end = header_len + 17;
250 if data.len() < end {
251 return Ok(None);
252 }
253 let change_id =
254 ChangeId::from_bytes(data[header_len + 1..end].try_into().map_err(|_| {
255 StoreError::InvalidObject("invalid change id length".to_string())
256 })?);
257 (PackObjectId::ChangeId(change_id), 17)
258 }
259 _ => {
260 return Err(StoreError::InvalidObject(format!(
261 "unknown pack object id tag {base_tag}"
262 )));
263 }
264 };
265 header_len += base_len;
266 Some(base)
267 } else {
268 None
269 };
270
271 Ok(Some(PackEntryHeader {
272 id,
273 obj_type,
274 uncompressed_size: checked_decoded_size("uncompressed_size", uncompressed_size)?,
275 compressed_size: checked_decoded_size("compressed_size", compressed_size)?,
276 delta_base,
277 header_len,
278 }))
279}
280
281fn checked_decoded_size(field: &str, size: u64) -> Result<usize> {
282 let size = usize::try_from(size).map_err(|_| {
283 StoreError::InvalidObject(format!("Decoded {field} exceeds platform limits"))
284 })?;
285 if field == "uncompressed_size" {
286 reject_pack_object_output_over_limit(size, MAX_PACK_OBJECT_OUTPUT_SIZE)?;
287 }
288 Ok(size)
289}
290
291pub fn compress_pack_payload(data: &[u8], config: &CompressionConfig) -> Result<Vec<u8>> {
292 if !config.enabled || data.len() < config.min_size {
293 return Ok(data.to_vec());
294 }
295 #[cfg(feature = "zstd")]
296 {
297 match zstd::encode_all(data, config.level) {
298 Ok(compressed) if compressed.len() < data.len() => Ok(compressed),
299 _ => Ok(data.to_vec()),
300 }
301 }
302 #[cfg(not(feature = "zstd"))]
303 {
304 let _ = config;
305 Ok(data.to_vec())
306 }
307}
308
309pub fn decompress_pack_payload(data: &[u8], expected_size: usize) -> Result<Vec<u8>> {
310 #[cfg(feature = "zstd")]
311 {
312 decompress_pack_payload_with_limit(data, expected_size, MAX_PACK_OBJECT_OUTPUT_SIZE)
313 }
314 #[cfg(not(feature = "zstd"))]
315 {
316 reject_pack_object_output_over_limit(expected_size, MAX_PACK_OBJECT_OUTPUT_SIZE)?;
317 reject_pack_object_output_over_limit(data.len(), MAX_PACK_OBJECT_OUTPUT_SIZE)?;
318 Ok(data.to_vec())
319 }
320}
321
/// Stream-decompresses a zstd payload, refusing to produce more than
/// `max_output_size` bytes.
///
/// `expected_size` is checked against the limit up front and is otherwise
/// used only to size the initial buffer; the output length is not compared
/// with it here.
///
/// # Errors
///
/// Returns [`StoreError::InvalidObject`] when `expected_size` or the produced
/// output exceeds `max_output_size`, or when the zstd decoder fails to
/// initialise or read.
#[cfg(feature = "zstd")]
pub(super) fn decompress_pack_payload_with_limit(
    data: &[u8],
    expected_size: usize,
    max_output_size: usize,
) -> Result<Vec<u8>> {
    use std::io::Read;

    reject_pack_object_output_over_limit(expected_size, max_output_size)?;

    let mut decoder = match zstd::stream::read::Decoder::new(data) {
        Ok(decoder) => decoder,
        Err(e) => {
            return Err(StoreError::InvalidObject(format!(
                "zstd decode init failed: {e}"
            )));
        }
    };
    let mut output = Vec::with_capacity(initial_decompression_capacity(
        data.len(),
        expected_size,
        max_output_size,
    ));
    let mut scratch = [0u8; 8192];

    loop {
        let filled = match decoder.read(&mut scratch) {
            Ok(0) => break,
            Ok(count) => count,
            Err(e) => {
                return Err(StoreError::InvalidObject(format!(
                    "zstd decompression failed: {e}"
                )));
            }
        };

        // Enforce the limit before growing the buffer, so an oversized
        // stream is rejected without ever holding more than the cap.
        let Some(grown) = output.len().checked_add(filled) else {
            return Err(StoreError::InvalidObject(
                "Pack object output size overflows".to_string(),
            ));
        };
        reject_pack_object_output_over_limit(grown, max_output_size)?;
        output.extend_from_slice(&scratch[..filled]);
    }

    Ok(output)
}
359
/// Picks the buffer capacity to reserve before decompression starts.
///
/// Uses `expected_size` when it is non-zero, otherwise twice
/// `compressed_len` (saturating). The result is clamped to both
/// `PACK_DECOMPRESSION_INITIAL_CAP` and `max_output_size`.
#[cfg(feature = "zstd")]
fn initial_decompression_capacity(
    compressed_len: usize,
    expected_size: usize,
    max_output_size: usize,
) -> usize {
    let hint = match expected_size {
        0 => compressed_len.saturating_mul(2),
        declared => declared,
    };
    let ceiling = PACK_DECOMPRESSION_INITIAL_CAP.min(max_output_size);
    hint.min(ceiling)
}
374
375fn reject_pack_object_output_over_limit(size: usize, max: usize) -> Result<()> {
376 if size > max {
377 return Err(StoreError::InvalidObject(format!(
378 "Pack object output size {size} exceeds max {max}"
379 )));
380 }
381 Ok(())
382}
383
/// Reports whether `data` begins with the four bytes `28 B5 2F FD`, the
/// Zstandard frame magic number.
///
/// Inputs shorter than four bytes never match.
pub fn has_zstd_magic(data: &[u8]) -> bool {
    const ZSTD_FRAME_MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];
    data.starts_with(&ZSTD_FRAME_MAGIC)
}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a Blob entry for `id` that declares `declared_size` as its
    /// uncompressed size and carries a one-byte payload.
    fn blob_entry_declaring(id: PackObjectId, declared_size: u64) -> Vec<u8> {
        let mut bytes = Vec::new();
        id.encode_tagged(&mut bytes);
        varint::encode_type_and_size(ObjectType::Blob, declared_size, &mut bytes);
        varint::encode_varint(1, &mut bytes);
        bytes.push(0);
        bytes
    }

    /// Asserts that `error` is one of the two size-limit rejections the
    /// header decoder can produce.
    fn assert_size_limit_error(error: &StoreError) {
        assert!(
            matches!(error, StoreError::InvalidObject(message) if message.contains("platform limits") || message.contains("Pack object output size")),
            "expected size-limit InvalidObject, got: {error:?}",
        );
    }

    #[test]
    fn tagged_pack_object_ids_round_trip() {
        let samples = [
            PackObjectId::Hash(ContentHash::compute(b"hash-object")),
            PackObjectId::ChangeId(ChangeId::generate()),
        ];

        for original in samples {
            let mut wire = Vec::new();
            original.encode_tagged(&mut wire);

            let (round_tripped, used) = PackObjectId::decode_tagged(&wire).unwrap();

            assert_eq!(round_tripped, original);
            assert_eq!(used, wire.len());
        }
    }

    #[test]
    fn tagged_entry_header_round_trips_mixed_identity() {
        let record = PackObjectRecord {
            id: PackObjectId::ChangeId(ChangeId::generate()),
            obj_type: ObjectType::State,
            data: vec![1, 2, 3, 4, 5],
            delta_base: None,
            path_hint: None,
        };

        let mut wire = Vec::new();
        encode_tagged_entry(&mut wire, &record, record.obj_type, &record.data).unwrap();
        let header = decode_tagged_entry_header(&wire).unwrap();

        assert_eq!(header.id, record.id);
        assert_eq!(header.obj_type, ObjectType::State);
        assert_eq!(header.uncompressed_size, 5);
        assert_eq!(header.compressed_size, 5);
        assert_eq!(header.delta_base, None);
    }

    #[test]
    fn tagged_entry_header_rejects_size_that_truncates_on_32_bit() {
        let wire = blob_entry_declaring(
            PackObjectId::Hash(ContentHash::compute(b"oversized-pack-object")),
            u64::from(u32::MAX) + 1,
        );

        let error = decode_tagged_entry_header(&wire)
            .expect_err("absurd 32-bit-overflow size must be rejected");

        assert_size_limit_error(&error);
    }

    #[test]
    fn tagged_entry_header_rejects_u64_max_size_when_platform_cannot_represent_it() {
        let wire = blob_entry_declaring(
            PackObjectId::Hash(ContentHash::compute(b"u64-max-pack-object")),
            u64::MAX,
        );

        let error =
            decode_tagged_entry_header(&wire).expect_err("absurd u64::MAX size must be rejected");

        assert_size_limit_error(&error);
    }
}