1#![deny(clippy::cast_possible_truncation)]
3
4use heddle_format::compression::CompressionConfig;
5
6use super::{
7 ObjectType, varint,
8 versioned_header::{HeaderChecksum, VersionedHeader},
9};
10use crate::{
11 object::{ChangeId, ContentHash},
12 store::{Result, StoreError},
13};
14
/// Byte length of a pack checksum (32 bytes).
pub const PACK_CHECKSUM_LEN: usize = 32;

/// Largest decoded size, in bytes, accepted for a single pack object (1 GiB).
///
/// Declared uncompressed sizes read from entry headers and the output of
/// payload decompression are both rejected once they exceed this value.
pub const MAX_PACK_OBJECT_OUTPUT_SIZE: usize = 1 << 30;

/// Ceiling (4 MiB) on the capacity reserved up front for a decompression
/// buffer, so a large declared size does not translate into a large eager
/// allocation.
#[cfg(feature = "zstd")]
pub(super) const PACK_DECOMPRESSION_INITIAL_CAP: usize = 4 << 20;
19
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
21pub enum PackObjectId {
22 Hash(ContentHash),
23 ChangeId(ChangeId),
24}
25
26impl PackObjectId {
27 pub fn encode_tagged(self, buf: &mut Vec<u8>) {
28 match self {
29 Self::Hash(hash) => {
30 buf.push(0);
31 buf.extend_from_slice(hash.as_bytes());
32 }
33 Self::ChangeId(change_id) => {
34 buf.push(1);
35 buf.extend_from_slice(change_id.as_bytes());
36 }
37 }
38 }
39
40 pub fn decode_tagged(data: &[u8]) -> Result<(Self, usize)> {
41 let Some(tag) = data.first().copied() else {
42 return Err(StoreError::InvalidObject(
43 "missing pack object id tag".to_string(),
44 ));
45 };
46 match tag {
47 0 => {
48 if data.len() < 33 {
49 return Err(StoreError::InvalidObject(
50 "hash pack object id truncated".to_string(),
51 ));
52 }
53 let hash = ContentHash::from_bytes(data[1..33].try_into().map_err(|_| {
54 StoreError::InvalidObject("invalid hash id length".to_string())
55 })?);
56 Ok((Self::Hash(hash), 33))
57 }
58 1 => {
59 if data.len() < 17 {
60 return Err(StoreError::InvalidObject(
61 "change id pack object id truncated".to_string(),
62 ));
63 }
64 let change_id = ChangeId::from_bytes(data[1..17].try_into().map_err(|_| {
65 StoreError::InvalidObject("invalid change id length".to_string())
66 })?);
67 Ok((Self::ChangeId(change_id), 17))
68 }
69 _ => Err(StoreError::InvalidObject(format!(
70 "unknown pack object id tag {tag}"
71 ))),
72 }
73 }
74}
75
76#[derive(Debug, Clone)]
77pub struct PackObjectRecord {
78 pub id: PackObjectId,
79 pub obj_type: ObjectType,
80 pub data: Vec<u8>,
81 pub delta_base: Option<PackObjectId>,
82 pub path_hint: Option<String>,
83}
84
/// Magic bytes and format version that identify one kind of pack container.
#[derive(Clone, Copy, Debug)]
pub struct PackContainerSpec {
    /// Four-byte magic placed in the container header.
    pub magic: &'static [u8; 4],
    /// Container format version placed in the header.
    pub version: u32,
}
90
91#[derive(Debug, Clone)]
92pub struct PackEntryHeader {
93 pub id: PackObjectId,
94 pub obj_type: ObjectType,
95 pub uncompressed_size: usize,
96 pub compressed_size: usize,
97 pub delta_base: Option<PackObjectId>,
98 pub header_len: usize,
99}
100
101pub fn write_container_header(buf: &mut Vec<u8>, spec: PackContainerSpec, count: u64) {
102 pack_container_header(spec).write_vec(buf, count);
103}
104
105pub fn verify_container(data: &[u8], spec: PackContainerSpec) -> Result<(u64, usize, usize)> {
106 let header = pack_container_header(spec).verify(data)?;
107 Ok((header.count, header.header_len, header.content_end))
108}
109
110pub fn append_container_checksum(buf: &mut Vec<u8>) {
111 HeaderChecksum::Blake3Trailer.append(buf);
112}
113
114fn pack_container_header(spec: PackContainerSpec) -> VersionedHeader {
115 VersionedHeader {
116 magic: spec.magic,
117 version: spec.version,
118 checksum: HeaderChecksum::Blake3Trailer,
119 too_short: "Pack too short",
120 invalid_magic: "Invalid pack magic",
121 unsupported_version: "Unsupported pack version",
122 checksum_mismatch: "Pack checksum mismatch",
123 }
124}
125
126pub fn encode_tagged_entry(
127 buf: &mut Vec<u8>,
128 record: &PackObjectRecord,
129 stored_type: ObjectType,
130 compressed: &[u8],
131) -> Result<()> {
132 encode_tagged_entry_parts(
133 buf,
134 record.id,
135 stored_type,
136 record.data.len(),
137 record.delta_base,
138 compressed,
139 )
140}
141
142pub fn encode_tagged_entry_parts(
143 buf: &mut Vec<u8>,
144 id: PackObjectId,
145 stored_type: ObjectType,
146 uncompressed_size: usize,
147 delta_base: Option<PackObjectId>,
148 compressed: &[u8],
149) -> Result<()> {
150 id.encode_tagged(buf);
151 varint::encode_type_and_size(stored_type, uncompressed_size as u64, buf);
152 varint::encode_varint(compressed.len() as u64, buf);
153 if stored_type == ObjectType::Delta {
154 let Some(base) = delta_base else {
155 return Err(StoreError::InvalidObject(
156 "Delta entry missing base id".to_string(),
157 ));
158 };
159 base.encode_tagged(buf);
160 }
161 buf.extend_from_slice(compressed);
162 Ok(())
163}
164
165pub fn decode_tagged_entry_header(data: &[u8]) -> Result<PackEntryHeader> {
166 let (id, id_len) = PackObjectId::decode_tagged(data)?;
167 let (obj_type, uncompressed_size, type_len) = varint::decode_type_and_size(&data[id_len..])
168 .ok_or_else(|| StoreError::InvalidObject("Truncated type+size varint".to_string()))?;
169 let varint_start = id_len + type_len;
170 let (compressed_size, comp_len) = varint::decode_varint(&data[varint_start..])
171 .ok_or_else(|| StoreError::InvalidObject("Truncated compressed_size varint".to_string()))?;
172 let mut header_len = varint_start + comp_len;
173
174 let delta_base = if obj_type == ObjectType::Delta {
175 let (base, base_len) = PackObjectId::decode_tagged(&data[header_len..])?;
176 header_len += base_len;
177 Some(base)
178 } else {
179 None
180 };
181
182 Ok(PackEntryHeader {
183 id,
184 obj_type,
185 uncompressed_size: checked_decoded_size("uncompressed_size", uncompressed_size)?,
186 compressed_size: checked_decoded_size("compressed_size", compressed_size)?,
187 delta_base,
188 header_len,
189 })
190}
191
192pub fn try_decode_tagged_entry_header(data: &[u8]) -> Result<Option<PackEntryHeader>> {
193 let Some(tag) = data.first().copied() else {
194 return Ok(None);
195 };
196
197 let (id, id_len) =
198 match tag {
199 0 => {
200 if data.len() < 33 {
201 return Ok(None);
202 }
203 let hash = ContentHash::from_bytes(data[1..33].try_into().map_err(|_| {
204 StoreError::InvalidObject("invalid hash id length".to_string())
205 })?);
206 (PackObjectId::Hash(hash), 33)
207 }
208 1 => {
209 if data.len() < 17 {
210 return Ok(None);
211 }
212 let change_id = ChangeId::from_bytes(data[1..17].try_into().map_err(|_| {
213 StoreError::InvalidObject("invalid change id length".to_string())
214 })?);
215 (PackObjectId::ChangeId(change_id), 17)
216 }
217 _ => {
218 return Err(StoreError::InvalidObject(format!(
219 "unknown pack object id tag {tag}"
220 )));
221 }
222 };
223
224 let Some((obj_type, uncompressed_size, type_len)) =
225 varint::decode_type_and_size(&data[id_len..])
226 else {
227 return Ok(None);
228 };
229 let varint_start = id_len + type_len;
230 let Some((compressed_size, comp_len)) = varint::decode_varint(&data[varint_start..]) else {
231 return Ok(None);
232 };
233 let mut header_len = varint_start + comp_len;
234
235 let delta_base = if obj_type == ObjectType::Delta {
236 let Some(base_tag) = data.get(header_len).copied() else {
237 return Ok(None);
238 };
239 let (base, base_len) = match base_tag {
240 0 => {
241 let end = header_len + 33;
242 if data.len() < end {
243 return Ok(None);
244 }
245 let hash = ContentHash::from_bytes(data[header_len + 1..end].try_into().map_err(
246 |_| StoreError::InvalidObject("invalid hash id length".to_string()),
247 )?);
248 (PackObjectId::Hash(hash), 33)
249 }
250 1 => {
251 let end = header_len + 17;
252 if data.len() < end {
253 return Ok(None);
254 }
255 let change_id =
256 ChangeId::from_bytes(data[header_len + 1..end].try_into().map_err(|_| {
257 StoreError::InvalidObject("invalid change id length".to_string())
258 })?);
259 (PackObjectId::ChangeId(change_id), 17)
260 }
261 _ => {
262 return Err(StoreError::InvalidObject(format!(
263 "unknown pack object id tag {base_tag}"
264 )));
265 }
266 };
267 header_len += base_len;
268 Some(base)
269 } else {
270 None
271 };
272
273 Ok(Some(PackEntryHeader {
274 id,
275 obj_type,
276 uncompressed_size: checked_decoded_size("uncompressed_size", uncompressed_size)?,
277 compressed_size: checked_decoded_size("compressed_size", compressed_size)?,
278 delta_base,
279 header_len,
280 }))
281}
282
283fn checked_decoded_size(field: &str, size: u64) -> Result<usize> {
284 let size = usize::try_from(size).map_err(|_| {
285 StoreError::InvalidObject(format!("Decoded {field} exceeds platform limits"))
286 })?;
287 if field == "uncompressed_size" {
288 reject_pack_object_output_over_limit(size, MAX_PACK_OBJECT_OUTPUT_SIZE)?;
289 }
290 Ok(size)
291}
292
293pub fn compress_pack_payload(data: &[u8], config: &CompressionConfig) -> Result<Vec<u8>> {
294 if !config.enabled || data.len() < config.min_size {
295 return Ok(data.to_vec());
296 }
297 #[cfg(feature = "zstd")]
298 {
299 match zstd::encode_all(data, config.level) {
300 Ok(compressed) if compressed.len() < data.len() => Ok(compressed),
301 _ => Ok(data.to_vec()),
302 }
303 }
304 #[cfg(not(feature = "zstd"))]
305 {
306 let _ = config;
307 Ok(data.to_vec())
308 }
309}
310
311pub fn decompress_pack_payload(data: &[u8], expected_size: usize) -> Result<Vec<u8>> {
312 #[cfg(feature = "zstd")]
313 {
314 decompress_pack_payload_with_limit(data, expected_size, MAX_PACK_OBJECT_OUTPUT_SIZE)
315 }
316 #[cfg(not(feature = "zstd"))]
317 {
318 reject_pack_object_output_over_limit(expected_size, MAX_PACK_OBJECT_OUTPUT_SIZE)?;
319 reject_pack_object_output_over_limit(data.len(), MAX_PACK_OBJECT_OUTPUT_SIZE)?;
320 Ok(data.to_vec())
321 }
322}
323
/// Streams `data` through a zstd decoder, refusing to produce more than
/// `max_output_size` bytes.
///
/// `expected_size` is rejected up front if it exceeds the limit and otherwise
/// only feeds the initial buffer capacity; the decoded length is not compared
/// against it here. Output is accumulated in 8 KiB chunks and the limit is
/// checked before each chunk is appended.
///
/// # Errors
///
/// Returns [`StoreError::InvalidObject`] when the limit is exceeded, when the
/// decoder cannot be created, or when decoding fails.
#[cfg(feature = "zstd")]
pub(super) fn decompress_pack_payload_with_limit(
    data: &[u8],
    expected_size: usize,
    max_output_size: usize,
) -> Result<Vec<u8>> {
    use std::io::Read;

    reject_pack_object_output_over_limit(expected_size, max_output_size)?;

    let mut decoder = match zstd::stream::read::Decoder::new(data) {
        Ok(decoder) => decoder,
        Err(e) => {
            return Err(StoreError::InvalidObject(format!(
                "zstd decode init failed: {e}"
            )));
        }
    };
    let mut output = Vec::with_capacity(initial_decompression_capacity(
        data.len(),
        expected_size,
        max_output_size,
    ));
    let mut scratch = [0u8; 8192];

    loop {
        let filled = match decoder.read(&mut scratch) {
            // A zero-length read marks the end of the stream.
            Ok(0) => return Ok(output),
            Ok(count) => count,
            Err(e) => {
                return Err(StoreError::InvalidObject(format!(
                    "zstd decompression failed: {e}"
                )));
            }
        };

        let Some(total) = output.len().checked_add(filled) else {
            return Err(StoreError::InvalidObject(
                "Pack object output size overflows".to_string(),
            ));
        };
        // Check before appending so the buffer never grows past the limit.
        reject_pack_object_output_over_limit(total, max_output_size)?;
        output.extend_from_slice(&scratch[..filled]);
    }
}
361
/// Chooses the capacity to reserve before decompressing.
///
/// Starts from `expected_size`, or from twice `compressed_len` (saturating)
/// when `expected_size` is zero, and clamps the result to both
/// [`PACK_DECOMPRESSION_INITIAL_CAP`] and `max_output_size`.
#[cfg(feature = "zstd")]
fn initial_decompression_capacity(
    compressed_len: usize,
    expected_size: usize,
    max_output_size: usize,
) -> usize {
    let hint = match expected_size {
        0 => compressed_len.saturating_mul(2),
        declared => declared,
    };
    let capped = std::cmp::min(hint, PACK_DECOMPRESSION_INITIAL_CAP);
    std::cmp::min(capped, max_output_size)
}
376
377fn reject_pack_object_output_over_limit(size: usize, max: usize) -> Result<()> {
378 if size > max {
379 return Err(StoreError::InvalidObject(format!(
380 "Pack object output size {size} exceeds max {max}"
381 )));
382 }
383 Ok(())
384}
385
/// Reports whether `data` begins with the four bytes `28 B5 2F FD`, the zstd
/// frame magic number.
///
/// Inputs shorter than four bytes never match.
pub fn has_zstd_magic(data: &[u8]) -> bool {
    data.starts_with(&[0x28, 0xB5, 0x2F, 0xFD])
}
389
390#[cfg(test)]
391mod tests {
392 use super::*;
393
394 #[test]
395 fn tagged_pack_object_ids_round_trip() {
396 let ids = [
397 PackObjectId::Hash(ContentHash::compute(b"hash-object")),
398 PackObjectId::ChangeId(ChangeId::generate()),
399 ];
400
401 for id in ids {
402 let mut encoded = Vec::new();
403 id.encode_tagged(&mut encoded);
404 let (decoded, consumed) = PackObjectId::decode_tagged(&encoded).unwrap();
405 assert_eq!(decoded, id);
406 assert_eq!(consumed, encoded.len());
407 }
408 }
409
410 #[test]
411 fn tagged_entry_header_round_trips_mixed_identity() {
412 let record = PackObjectRecord {
413 id: PackObjectId::ChangeId(ChangeId::generate()),
414 obj_type: ObjectType::State,
415 data: vec![1, 2, 3, 4, 5],
416 delta_base: None,
417 path_hint: None,
418 };
419
420 let mut encoded = Vec::new();
421 encode_tagged_entry(&mut encoded, &record, record.obj_type, &record.data).unwrap();
422 let decoded = decode_tagged_entry_header(&encoded).unwrap();
423
424 assert_eq!(decoded.id, record.id);
425 assert_eq!(decoded.obj_type, ObjectType::State);
426 assert_eq!(decoded.uncompressed_size, 5);
427 assert_eq!(decoded.compressed_size, 5);
428 assert_eq!(decoded.delta_base, None);
429 }
430
431 #[test]
432 fn tagged_entry_header_rejects_size_that_truncates_on_32_bit() {
433 let mut encoded = Vec::new();
434 PackObjectId::Hash(ContentHash::compute(b"oversized-pack-object"))
435 .encode_tagged(&mut encoded);
436 varint::encode_type_and_size(ObjectType::Blob, u64::from(u32::MAX) + 1, &mut encoded);
437 varint::encode_varint(1, &mut encoded);
438 encoded.push(0);
439
440 let result = decode_tagged_entry_header(&encoded);
441
442 let error = result.expect_err("absurd 32-bit-overflow size must be rejected");
443 assert!(
444 matches!(&error, StoreError::InvalidObject(message) if message.contains("platform limits") || message.contains("Pack object output size")),
445 "expected size-limit InvalidObject, got: {error:?}",
446 );
447 }
448
449 #[test]
450 fn tagged_entry_header_rejects_u64_max_size_when_platform_cannot_represent_it() {
451 let mut encoded = Vec::new();
452 PackObjectId::Hash(ContentHash::compute(b"u64-max-pack-object"))
453 .encode_tagged(&mut encoded);
454 varint::encode_type_and_size(ObjectType::Blob, u64::MAX, &mut encoded);
455 varint::encode_varint(1, &mut encoded);
456 encoded.push(0);
457
458 let result = decode_tagged_entry_header(&encoded);
459
460 let error = result.expect_err("absurd u64::MAX size must be rejected");
461 assert!(
462 matches!(&error, StoreError::InvalidObject(message) if message.contains("platform limits") || message.contains("Pack object output size")),
463 "expected size-limit InvalidObject, got: {error:?}",
464 );
465 }
466}