use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::Path;

use haagenti_core::{Compressor, Decompressor, Error, Result};
use xxhash_rust::xxh3::xxh3_64;

pub const HCT_MAGIC: [u8; 4] = *b"HCTN";

pub const HCT_VERSION: u32 = 1;

pub const HCT_VERSION_V2: u32 = 2;

pub const FLAG_HEADER_CHECKSUM: u16 = 0x0001;

pub const FLAG_BLOCK_CHECKSUMS: u16 = 0x0002;

pub const FLAG_QUANTIZATION: u16 = 0x0004;

pub const FLAG_TENSOR_NAME: u16 = 0x0008;

pub const FLAG_HOLOGRAPHIC: u16 = 0x0010;

pub const DEFAULT_BLOCK_SIZE: u32 = 16 * 1024;
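// On-disk layout, as produced by the writers below (all integers are
// little-endian):
//
//   v1: [ 64-byte HctHeader | num_blocks x 8-byte BlockIndex | block data ... ]
//   v2: [ 64-byte HctHeader | optional 8-byte QuantizationMetadata
//                           | num_blocks x 16-byte BlockIndexV2 | block data ... ]
//
// In v2, bytes 56..64 of the header hold an xxh3-64 checksum of bytes 0..56,
// and each BlockIndexV2 entry carries an xxh3-64 checksum of its compressed
// block.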
75
76#[derive(Debug, Clone, Copy, PartialEq, Eq)]
78#[repr(u8)]
79pub enum CompressionAlgorithm {
80 Lz4 = 0,
81 Zstd = 1,
82}
83
84impl TryFrom<u8> for CompressionAlgorithm {
85 type Error = Error;
86
87 fn try_from(value: u8) -> Result<Self> {
88 match value {
89 0 => Ok(CompressionAlgorithm::Lz4),
90 1 => Ok(CompressionAlgorithm::Zstd),
91 _ => Err(Error::corrupted(format!("unknown algorithm: {}", value))),
92 }
93 }
94}
95
96#[derive(Debug, Clone, Copy, PartialEq, Eq)]
98#[repr(u8)]
99pub enum DType {
100 F32 = 0,
101 F16 = 1,
102 BF16 = 2,
103 I8 = 3,
104 I4 = 4,
105}
106
107impl DType {
108 pub fn bits(&self) -> usize {
110 match self {
111 DType::F32 => 32,
112 DType::F16 | DType::BF16 => 16,
113 DType::I8 => 8,
114 DType::I4 => 4,
115 }
116 }
117
118 pub fn bytes(&self) -> usize {
120 self.bits().div_ceil(8)
121 }
122}
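// Width helpers at a glance (informal note): `DType::F32.bytes()` is 4,
// `DType::BF16.bytes()` is 2, and the sub-byte `DType::I4` reports 4 bits but
// rounds up to 1 byte via `div_ceil`, so buffer sizing based on `bytes()`
// never under-allocates.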
123
124impl TryFrom<u8> for DType {
125 type Error = Error;
126
127 fn try_from(value: u8) -> Result<Self> {
128 match value {
129 0 => Ok(DType::F32),
130 1 => Ok(DType::F16),
131 2 => Ok(DType::BF16),
132 3 => Ok(DType::I8),
133 4 => Ok(DType::I4),
134 _ => Err(Error::corrupted(format!("unknown dtype: {}", value))),
135 }
136 }
137}
138
139#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
143#[repr(u8)]
144pub enum QuantizationScheme {
145 #[default]
147 None = 0,
148 GptqInt4 = 1,
150 AwqInt4 = 2,
152 SymmetricInt8 = 3,
154 AsymmetricInt8 = 4,
156}
157
158impl TryFrom<u8> for QuantizationScheme {
159 type Error = Error;
160
161 fn try_from(value: u8) -> Result<Self> {
162 match value {
163 0 => Ok(QuantizationScheme::None),
164 1 => Ok(QuantizationScheme::GptqInt4),
165 2 => Ok(QuantizationScheme::AwqInt4),
166 3 => Ok(QuantizationScheme::SymmetricInt8),
167 4 => Ok(QuantizationScheme::AsymmetricInt8),
168 _ => Err(Error::corrupted(format!(
169 "unknown quantization scheme: {}",
170 value
171 ))),
172 }
173 }
174}
175
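/// Compact 8-byte description of how a tensor was quantized, serialized by
/// `to_bytes`/`from_bytes` below:
///
/// - byte 0: `scheme`
/// - byte 1: `has_per_group_scales` (0 or 1)
/// - bytes 2..4: `scale_bits`, little-endian
/// - byte 4: `zero_point` (stored as its `u8` bit pattern)
/// - bytes 5..8: low three bytes of `group_size` (values of 2^24 or more are truncated)
///
/// A minimal round-trip sketch (not compiled here; the field values are only
/// illustrative):
///
/// ```ignore
/// let meta = QuantizationMetadata {
///     scheme: QuantizationScheme::AwqInt4,
///     group_size: 128,
///     scale_bits: 0x3C00, // 1.0 as IEEE half-precision bits
///     zero_point: 8,
///     has_per_group_scales: true,
/// };
/// assert_eq!(QuantizationMetadata::from_bytes(&meta.to_bytes()).unwrap(), meta);
/// ```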
176#[derive(Debug, Clone, Default, PartialEq)]
180pub struct QuantizationMetadata {
181 pub scheme: QuantizationScheme,
183 pub group_size: u32,
185 pub scale_bits: u16,
187 pub zero_point: i8,
189 pub has_per_group_scales: bool,
191}
192
193impl QuantizationMetadata {
194 pub const SIZE: usize = 8;
196
197 pub fn to_bytes(&self) -> [u8; Self::SIZE] {
199 let mut buf = [0u8; Self::SIZE];
200 buf[0] = self.scheme as u8;
201 buf[1] = if self.has_per_group_scales { 1 } else { 0 };
202 buf[2..4].copy_from_slice(&self.scale_bits.to_le_bytes());
203 buf[4] = self.zero_point as u8;
204 buf[5..8].copy_from_slice(&self.group_size.to_le_bytes()[..3]);
205 buf
206 }
207
208 pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Result<Self> {
210 let scheme = QuantizationScheme::try_from(buf[0])?;
211 let has_per_group_scales = buf[1] != 0;
212 let scale_bits = u16::from_le_bytes([buf[2], buf[3]]);
213 let zero_point = buf[4] as i8;
214 let mut group_size_bytes = [0u8; 4];
215 group_size_bytes[..3].copy_from_slice(&buf[5..8]);
216 let group_size = u32::from_le_bytes(group_size_bytes);
217
218 Ok(Self {
219 scheme,
220 group_size,
221 scale_bits,
222 zero_point,
223 has_per_group_scales,
224 })
225 }
226}
227
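/// Per-block index entry for v2 containers (16 bytes, little-endian).
/// `offset` and `compressed_size` mirror the v1 `BlockIndex`; `checksum` holds
/// the xxh3-64 of the compressed block. A checksum of 0 means "none recorded"
/// (the value `from_v1` fills in), and the readers skip verification for such
/// entries.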
228#[derive(Debug, Clone, Copy)]
232pub struct BlockIndexV2 {
233 pub offset: u32,
235 pub compressed_size: u32,
237 pub checksum: u64,
239}
240
241impl BlockIndexV2 {
242 pub const SIZE: usize = 16;
244
245 pub fn to_bytes(&self) -> [u8; Self::SIZE] {
247 let mut buf = [0u8; Self::SIZE];
248 buf[0..4].copy_from_slice(&self.offset.to_le_bytes());
249 buf[4..8].copy_from_slice(&self.compressed_size.to_le_bytes());
250 buf[8..16].copy_from_slice(&self.checksum.to_le_bytes());
251 buf
252 }
253
254 pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Self {
256 Self {
257 offset: u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]),
258 compressed_size: u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]),
259 checksum: u64::from_le_bytes(buf[8..16].try_into().unwrap()),
260 }
261 }
262
263 pub fn from_v1(v1: BlockIndex) -> Self {
265 Self {
266 offset: v1.offset,
267 compressed_size: v1.compressed_size,
268 checksum: 0,
269 }
270 }
271}
272
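/// Fixed 64-byte header shared by v1 and v2 containers.
///
/// Byte layout written by `to_bytes` (little-endian):
/// - 0..4: magic `"HCTN"`, 4..8: format version
/// - 8: algorithm, 9: dtype, 10..12: flags
/// - 12..20: `original_size`, 20..28: `compressed_size`
/// - 28..32: `block_size`, 32..36: `num_blocks`
/// - 36: shape rank, 37..: shape dims as `u64`s (at most three fit in v1)
/// - 56..64: xxh3-64 header checksum in v2, which leaves room for two dims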
273#[derive(Debug, Clone)]
275pub struct HctHeader {
276 pub algorithm: CompressionAlgorithm,
278 pub dtype: DType,
280 pub flags: u16,
282 pub original_size: u64,
284 pub compressed_size: u64,
286 pub block_size: u32,
288 pub num_blocks: u32,
290 pub shape: Vec<u64>,
292}
293
294impl HctHeader {
295 pub const SIZE: usize = 64;
297
298 pub fn to_bytes(&self) -> [u8; Self::SIZE] {
300 let mut buf = [0u8; Self::SIZE];
301
302 buf[0..4].copy_from_slice(&HCT_MAGIC);
304
305 buf[4..8].copy_from_slice(&HCT_VERSION.to_le_bytes());
307
308 buf[8] = self.algorithm as u8;
310 buf[9] = self.dtype as u8;
311
312 buf[10..12].copy_from_slice(&self.flags.to_le_bytes());
314
315 buf[12..20].copy_from_slice(&self.original_size.to_le_bytes());
317 buf[20..28].copy_from_slice(&self.compressed_size.to_le_bytes());
318 buf[28..32].copy_from_slice(&self.block_size.to_le_bytes());
319 buf[32..36].copy_from_slice(&self.num_blocks.to_le_bytes());
320
        buf[36] = self.shape.len() as u8;
        // Only three 8-byte dims fit in the 64-byte header (offsets 37..61);
        // writing a fourth would index past the end of `buf`.
        for (i, &dim) in self.shape.iter().take(3).enumerate() {
            let offset = 37 + i * 8;
            buf[offset..offset + 8].copy_from_slice(&dim.to_le_bytes());
        }
327
328 buf
329 }
330
331 pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Result<Self> {
333 if buf[0..4] != HCT_MAGIC {
335 return Err(Error::corrupted("invalid HCT magic"));
336 }
337
338 let version = u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]);
340 if version > HCT_VERSION_V2 {
341 return Err(Error::corrupted(format!(
342 "unsupported HCT version: {} (max: {})",
343 version, HCT_VERSION_V2
344 )));
345 }
346
347 let algorithm = CompressionAlgorithm::try_from(buf[8])?;
348 let dtype = DType::try_from(buf[9])?;
349 let flags = u16::from_le_bytes([buf[10], buf[11]]);
350
351 let original_size = u64::from_le_bytes(buf[12..20].try_into().unwrap());
352 let compressed_size = u64::from_le_bytes(buf[20..28].try_into().unwrap());
353 let block_size = u32::from_le_bytes(buf[28..32].try_into().unwrap());
354 let num_blocks = u32::from_le_bytes(buf[32..36].try_into().unwrap());
355
        let rank = buf[36] as usize;
        let mut shape = Vec::with_capacity(rank.min(3));
        // Mirror `to_bytes`: at most three dims are stored in the fixed header.
        for i in 0..rank.min(3) {
            let offset = 37 + i * 8;
            let dim = u64::from_le_bytes(buf[offset..offset + 8].try_into().unwrap());
            shape.push(dim);
        }
363
364 Ok(Self {
365 algorithm,
366 dtype,
367 flags,
368 original_size,
369 compressed_size,
370 block_size,
371 num_blocks,
372 shape,
373 })
374 }
375}
376
377#[derive(Debug, Clone, Copy)]
379pub struct BlockIndex {
380 pub offset: u32,
382 pub compressed_size: u32,
384}
385
386impl BlockIndex {
387 pub const SIZE: usize = 8;
389
390 pub fn to_bytes(&self) -> [u8; Self::SIZE] {
392 let mut buf = [0u8; Self::SIZE];
393 buf[0..4].copy_from_slice(&self.offset.to_le_bytes());
394 buf[4..8].copy_from_slice(&self.compressed_size.to_le_bytes());
395 buf
396 }
397
398 pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Self {
400 Self {
401 offset: u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]),
402 compressed_size: u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]),
403 }
404 }
405}
406
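/// Reader for v1 `.hct` containers over any `Read + Seek` source.
///
/// A usage sketch (not compiled here). `Lz4Decompressor` is borrowed from the
/// optional `haagenti_lz4` backend exercised in the tests below; any
/// `Decompressor` implementation works:
///
/// ```ignore
/// use haagenti_core::Result;
/// use haagenti_lz4::Lz4Decompressor;
///
/// fn load_tensor(file: std::fs::File) -> Result<Vec<u8>> {
///     let mut reader = HctReader::new(file)?;
///     // Shape and dtype are available before any block is decompressed.
///     let _shape = reader.header().shape.clone();
///     reader.decompress_all(&Lz4Decompressor::new())
/// }
/// ```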
407pub struct HctReader<R: Read + Seek> {
409 reader: R,
410 header: HctHeader,
411 block_index: Vec<BlockIndex>,
412 data_offset: u64,
413}
414
415impl<R: Read + Seek> HctReader<R> {
416 pub fn new(mut reader: R) -> Result<Self> {
418 let mut header_buf = [0u8; HctHeader::SIZE];
420 reader
421 .read_exact(&mut header_buf)
422 .map_err(|e| Error::algorithm("hct", format!("failed to read header: {}", e)))?;
423 let header = HctHeader::from_bytes(&header_buf)?;
424
425 let index_size = header.num_blocks as usize * BlockIndex::SIZE;
427 let mut index_buf = vec![0u8; index_size];
428 reader
429 .read_exact(&mut index_buf)
430 .map_err(|e| Error::algorithm("hct", format!("failed to read block index: {}", e)))?;
431
432 let block_index: Vec<BlockIndex> = index_buf
433 .chunks_exact(BlockIndex::SIZE)
434 .map(|chunk| BlockIndex::from_bytes(chunk.try_into().unwrap()))
435 .collect();
436
437 let data_offset = HctHeader::SIZE as u64 + index_size as u64;
438
439 Ok(Self {
440 reader,
441 header,
442 block_index,
443 data_offset,
444 })
445 }
446
447 pub fn header(&self) -> &HctHeader {
449 &self.header
450 }
451
452 pub fn num_blocks(&self) -> usize {
454 self.block_index.len()
455 }
456
457 pub fn read_block(&mut self, block_idx: usize) -> Result<Vec<u8>> {
459 if block_idx >= self.block_index.len() {
460 return Err(Error::corrupted(format!(
461 "block index out of range: {} >= {}",
462 block_idx,
463 self.block_index.len()
464 )));
465 }
466
467 let index = &self.block_index[block_idx];
468 let offset = self.data_offset + index.offset as u64;
469
470 self.reader.seek(SeekFrom::Start(offset)).map_err(|e| {
471 Error::algorithm(
472 "hct",
473 format!("failed to seek to block {}: {}", block_idx, e),
474 )
475 })?;
476
477 let mut buf = vec![0u8; index.compressed_size as usize];
478 self.reader.read_exact(&mut buf).map_err(|e| {
479 Error::algorithm("hct", format!("failed to read block {}: {}", block_idx, e))
480 })?;
481
482 Ok(buf)
483 }
484
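    /// Decompresses a single block by index.
    ///
    /// Every block except the last expands to exactly `block_size` bytes; the
    /// final block holds the remainder,
    /// `original_size - (num_blocks - 1) * block_size`. For example, 65_537
    /// bytes at a 16 KiB block size yield four full blocks plus a 1-byte
    /// final block.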
485 pub fn decompress_block(
487 &mut self,
488 block_idx: usize,
489 decompressor: &impl Decompressor,
490 ) -> Result<Vec<u8>> {
491 let compressed = self.read_block(block_idx)?;
492
493 let is_last_block = block_idx == self.block_index.len() - 1;
495 let expected_size = if is_last_block {
496 let full_blocks = (self.block_index.len() - 1) as u64 * self.header.block_size as u64;
497 (self.header.original_size - full_blocks) as usize
498 } else {
499 self.header.block_size as usize
500 };
501
502 decompressor.decompress_with_size(&compressed, expected_size)
503 }
504
505 pub fn decompress_all(&mut self, decompressor: &impl Decompressor) -> Result<Vec<u8>> {
507 let mut output = Vec::with_capacity(self.header.original_size as usize);
508
509 for block_idx in 0..self.block_index.len() {
510 let decompressed = self.decompress_block(block_idx, decompressor)?;
511 output.extend_from_slice(&decompressed);
512 }
513
514 Ok(output)
515 }
516}
517
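/// Writer for v1 `.hct` containers: blocks are buffered in memory and the
/// header, block index, and data are emitted in one pass by `finish`.
///
/// A usage sketch (not compiled here); `Lz4Compressor` is the optional lz4
/// backend used by the tests, and the shape/dtype values are illustrative:
///
/// ```ignore
/// use haagenti_core::Result;
/// use haagenti_lz4::Lz4Compressor;
///
/// fn write_tensor(file: std::fs::File, data: &[u8]) -> Result<()> {
///     let mut writer =
///         HctWriter::new(file, CompressionAlgorithm::Lz4, DType::F16, vec![1024, 1024])
///             .with_block_size(64 * 1024);
///     writer.compress_data(data, &Lz4Compressor::new())?;
///     writer.finish()
/// }
/// ```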
518pub struct HctWriter<W: Write + Seek> {
520 writer: W,
521 algorithm: CompressionAlgorithm,
522 dtype: DType,
523 block_size: u32,
524 shape: Vec<u64>,
525 blocks: Vec<Vec<u8>>,
526 original_size: u64,
527}
528
529impl<W: Write + Seek> HctWriter<W> {
530 pub fn new(writer: W, algorithm: CompressionAlgorithm, dtype: DType, shape: Vec<u64>) -> Self {
532 Self {
533 writer,
534 algorithm,
535 dtype,
536 block_size: DEFAULT_BLOCK_SIZE,
537 shape,
538 blocks: Vec::new(),
539 original_size: 0,
540 }
541 }
542
543 pub fn with_block_size(mut self, block_size: u32) -> Self {
545 self.block_size = block_size;
546 self
547 }
548
549 pub fn add_block(&mut self, compressed: Vec<u8>, original_len: usize) {
551 self.blocks.push(compressed);
552 self.original_size += original_len as u64;
553 }
554
555 pub fn compress_data(&mut self, data: &[u8], compressor: &impl Compressor) -> Result<()> {
557 for chunk in data.chunks(self.block_size as usize) {
558 let compressed = compressor.compress(chunk)?;
559 self.add_block(compressed, chunk.len());
560 }
561 Ok(())
562 }
563
564 pub fn finish(mut self) -> Result<()> {
566 let mut block_index = Vec::with_capacity(self.blocks.len());
568 let mut offset = 0u32;
569
570 for block in &self.blocks {
571 block_index.push(BlockIndex {
572 offset,
573 compressed_size: block.len() as u32,
574 });
575 offset += block.len() as u32;
576 }
577
578 let compressed_size = offset as u64;
579
580 let header = HctHeader {
582 algorithm: self.algorithm,
583 dtype: self.dtype,
584 flags: 0,
585 original_size: self.original_size,
586 compressed_size,
587 block_size: self.block_size,
588 num_blocks: self.blocks.len() as u32,
589 shape: self.shape,
590 };
591
592 self.writer
594 .write_all(&header.to_bytes())
595 .map_err(|e| Error::algorithm("hct", format!("failed to write header: {}", e)))?;
596
597 for index in &block_index {
599 self.writer.write_all(&index.to_bytes()).map_err(|e| {
600 Error::algorithm("hct", format!("failed to write block index: {}", e))
601 })?;
602 }
603
604 for block in &self.blocks {
606 self.writer.write_all(block).map_err(|e| {
607 Error::algorithm("hct", format!("failed to write block data: {}", e))
608 })?;
609 }
610
611 self.writer
612 .flush()
613 .map_err(|e| Error::algorithm("hct", format!("failed to flush: {}", e)))?;
614
615 Ok(())
616 }
617}
618
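/// Convenience wrapper: reads `input_path` fully into memory, writes a v1
/// container to `output_path`, and reports size, ratio, and timing statistics.
///
/// A hedged sketch (not compiled here); `ZstdCompressor` comes from the
/// optional zstd backend used in the tests and the paths are placeholders:
///
/// ```ignore
/// use haagenti_zstd::ZstdCompressor;
///
/// let stats = compress_file(
///     "layer0.bin",
///     "layer0.hct",
///     &ZstdCompressor::new(),
///     DType::F16,
///     vec![4096, 4096],
/// )
/// .expect("compression failed");
/// println!("{:.2}x in {} ms", stats.ratio, stats.elapsed_ms);
/// ```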
619pub fn compress_file(
621 input_path: impl AsRef<Path>,
622 output_path: impl AsRef<Path>,
623 compressor: &impl Compressor,
624 dtype: DType,
625 shape: Vec<u64>,
626) -> Result<CompressionStats> {
627 use std::time::Instant;
628
629 let start = Instant::now();
630
631 let input_data = std::fs::read(input_path.as_ref())
633 .map_err(|e| Error::algorithm("hct", format!("failed to read input file: {}", e)))?;
634 let original_size = input_data.len();
635
636 let output_file = File::create(output_path.as_ref())
638 .map_err(|e| Error::algorithm("hct", format!("failed to create output file: {}", e)))?;
639
640 let algorithm = match compressor.algorithm() {
642 haagenti_core::Algorithm::Lz4 => CompressionAlgorithm::Lz4,
643 haagenti_core::Algorithm::Zstd => CompressionAlgorithm::Zstd,
644 _ => return Err(Error::corrupted("unsupported algorithm for HCT")),
645 };
646
647 let mut writer = HctWriter::new(output_file, algorithm, dtype, shape);
649 writer.compress_data(&input_data, compressor)?;
650 writer.finish()?;
651
652 let output_metadata = std::fs::metadata(output_path.as_ref())
654 .map_err(|e| Error::algorithm("hct", format!("failed to get output metadata: {}", e)))?;
655 let compressed_size = output_metadata.len() as usize;
656
657 let elapsed = start.elapsed();
658
659 Ok(CompressionStats {
660 original_size,
661 compressed_size,
662 ratio: original_size as f64 / compressed_size as f64,
663 elapsed_ms: elapsed.as_millis() as u64,
664 })
665}
666
667#[derive(Debug, Clone)]
669pub struct CompressionStats {
670 pub original_size: usize,
671 pub compressed_size: usize,
672 pub ratio: f64,
673 pub elapsed_ms: u64,
674}
675
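/// Writer for v2 containers. Compared to `HctWriter` it adds a header
/// checksum (bytes 0..56 hashed into bytes 56..64), per-block xxh3-64
/// checksums in the index, and an optional 8-byte `QuantizationMetadata`
/// record between the header and the index.
///
/// A usage sketch (not compiled here); the quantization values are
/// illustrative and `Lz4Compressor` is the optional backend from the tests:
///
/// ```ignore
/// use haagenti_core::Result;
/// use haagenti_lz4::Lz4Compressor;
///
/// fn write_quantized(file: std::fs::File, packed: &[u8]) -> Result<()> {
///     let quant = QuantizationMetadata {
///         scheme: QuantizationScheme::GptqInt4,
///         group_size: 128,
///         scale_bits: 0x3C00, // 1.0 as IEEE half-precision bits
///         zero_point: 0,
///         has_per_group_scales: true,
///     };
///     let mut writer =
///         HctWriterV2::new(file, CompressionAlgorithm::Lz4, DType::I4, vec![4096, 4096])
///             .with_quantization(quant);
///     writer.compress_data(packed, &Lz4Compressor::new())?;
///     writer.finish()
/// }
/// ```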
676pub struct HctWriterV2<W: Write + Seek> {
680 writer: W,
681 algorithm: CompressionAlgorithm,
682 dtype: DType,
683 block_size: u32,
684 shape: Vec<u64>,
    blocks: Vec<(Vec<u8>, u64)>,
    original_size: u64,
687 flags: u16,
688 quantization: Option<QuantizationMetadata>,
689}
690
691impl<W: Write + Seek> HctWriterV2<W> {
692 pub fn new(writer: W, algorithm: CompressionAlgorithm, dtype: DType, shape: Vec<u64>) -> Self {
694 Self {
695 writer,
696 algorithm,
697 dtype,
698 block_size: DEFAULT_BLOCK_SIZE,
699 shape,
700 blocks: Vec::new(),
701 original_size: 0,
702 flags: FLAG_HEADER_CHECKSUM | FLAG_BLOCK_CHECKSUMS,
703 quantization: None,
704 }
705 }
706
707 pub fn with_block_size(mut self, block_size: u32) -> Self {
709 self.block_size = block_size;
710 self
711 }
712
713 pub fn with_quantization(mut self, quant: QuantizationMetadata) -> Self {
715 self.quantization = Some(quant);
716 self.flags |= FLAG_QUANTIZATION;
717 self
718 }
719
720 pub fn without_block_checksums(mut self) -> Self {
722 self.flags &= !FLAG_BLOCK_CHECKSUMS;
723 self
724 }
725
726 pub fn add_block(&mut self, compressed: Vec<u8>, original_len: usize) {
728 let checksum = if self.flags & FLAG_BLOCK_CHECKSUMS != 0 {
729 xxh3_64(&compressed)
730 } else {
731 0
732 };
733 self.blocks.push((compressed, checksum));
734 self.original_size += original_len as u64;
735 }
736
737 pub fn compress_data(&mut self, data: &[u8], compressor: &impl Compressor) -> Result<()> {
739 for chunk in data.chunks(self.block_size as usize) {
740 let compressed = compressor.compress(chunk)?;
741 self.add_block(compressed, chunk.len());
742 }
743 Ok(())
744 }
745
746 pub fn finish(mut self) -> Result<()> {
748 let mut block_index = Vec::with_capacity(self.blocks.len());
750 let mut offset = 0u32;
751
752 for (block, checksum) in &self.blocks {
753 block_index.push(BlockIndexV2 {
754 offset,
755 compressed_size: block.len() as u32,
756 checksum: *checksum,
757 });
758 offset += block.len() as u32;
759 }
760
761 let compressed_size = offset as u64;
762
763 let mut header_bytes = [0u8; HctHeader::SIZE];
765
766 header_bytes[0..4].copy_from_slice(&HCT_MAGIC);
768
769 header_bytes[4..8].copy_from_slice(&HCT_VERSION_V2.to_le_bytes());
771
772 header_bytes[8] = self.algorithm as u8;
774 header_bytes[9] = self.dtype as u8;
775
776 header_bytes[10..12].copy_from_slice(&self.flags.to_le_bytes());
778
779 header_bytes[12..20].copy_from_slice(&self.original_size.to_le_bytes());
781 header_bytes[20..28].copy_from_slice(&compressed_size.to_le_bytes());
782 header_bytes[28..32].copy_from_slice(&self.block_size.to_le_bytes());
783 header_bytes[32..36].copy_from_slice(&(self.blocks.len() as u32).to_le_bytes());
784
        header_bytes[36] = self.shape.len() as u8;
        // The v2 header checksum occupies bytes 56..64 (written below), so only
        // two 8-byte dims (bytes 37..53) can be recorded without colliding with it.
        for (i, &dim) in self.shape.iter().take(2).enumerate() {
            let off = 37 + i * 8;
            header_bytes[off..off + 8].copy_from_slice(&dim.to_le_bytes());
        }
791
        let header_checksum = xxh3_64(&header_bytes[..56]);
        header_bytes[56..64].copy_from_slice(&header_checksum.to_le_bytes());
796
797 self.writer
799 .write_all(&header_bytes)
800 .map_err(|e| Error::algorithm("hct", format!("failed to write header: {}", e)))?;
801
802 if let Some(ref quant) = self.quantization {
804 self.writer.write_all(&quant.to_bytes()).map_err(|e| {
805 Error::algorithm("hct", format!("failed to write quantization: {}", e))
806 })?;
807 }
808
809 for index in &block_index {
811 self.writer.write_all(&index.to_bytes()).map_err(|e| {
812 Error::algorithm("hct", format!("failed to write block index: {}", e))
813 })?;
814 }
815
816 for (block, _) in &self.blocks {
818 self.writer.write_all(block).map_err(|e| {
819 Error::algorithm("hct", format!("failed to write block data: {}", e))
820 })?;
821 }
822
823 self.writer
824 .flush()
825 .map_err(|e| Error::algorithm("hct", format!("failed to flush: {}", e)))?;
826
827 Ok(())
828 }
829}
830
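/// Reader for v2 containers; it also accepts v1 files, since absent checksum
/// flags simply disable validation and the legacy 8-byte index entries are
/// upgraded to `BlockIndexV2` with a zero checksum.
///
/// A usage sketch (not compiled here); `Lz4Decompressor` is the optional
/// backend used in the tests:
///
/// ```ignore
/// use haagenti_core::Result;
/// use haagenti_lz4::Lz4Decompressor;
///
/// fn load_validated(file: std::fs::File) -> Result<Vec<u8>> {
///     let mut reader = HctReaderV2::new(file)?;
///     if let Some(q) = reader.quantization() {
///         // e.g. choose a dequantization kernel from q.scheme / q.group_size
///         let _ = (q.scheme, q.group_size);
///     }
///     reader.decompress_all_validated(&Lz4Decompressor::new())
/// }
/// ```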
831pub struct HctReaderV2<R: Read + Seek> {
833 reader: R,
834 header: HctHeader,
835 block_index: Vec<BlockIndexV2>,
836 data_offset: u64,
837 quantization: Option<QuantizationMetadata>,
838}
839
840impl<R: Read + Seek> HctReaderV2<R> {
841 pub fn new(mut reader: R) -> Result<Self> {
843 let mut header_buf = [0u8; HctHeader::SIZE];
845 reader
846 .read_exact(&mut header_buf)
847 .map_err(|e| Error::algorithm("hct", format!("failed to read header: {}", e)))?;
848
849 let header = HctHeader::from_bytes(&header_buf)?;
851
852 let stored_checksum = u64::from_le_bytes(header_buf[56..64].try_into().unwrap());
854
855 let version = u32::from_le_bytes(header_buf[4..8].try_into().unwrap());
857 if version >= HCT_VERSION_V2 && header.flags & FLAG_HEADER_CHECKSUM != 0 {
858 let computed = xxh3_64(&header_buf[..56]);
859 if computed != stored_checksum {
860 return Err(Error::corrupted(format!(
861 "header checksum mismatch: expected {:016x}, got {:016x}",
862 stored_checksum, computed
863 )));
864 }
865 }
866
867 let quantization = if header.flags & FLAG_QUANTIZATION != 0 {
869 let mut quant_buf = [0u8; QuantizationMetadata::SIZE];
870 reader.read_exact(&mut quant_buf).map_err(|e| {
871 Error::algorithm("hct", format!("failed to read quantization: {}", e))
872 })?;
873 Some(QuantizationMetadata::from_bytes(&quant_buf)?)
874 } else {
875 None
876 };
877
878 let index_entry_size =
880 if version >= HCT_VERSION_V2 && header.flags & FLAG_BLOCK_CHECKSUMS != 0 {
881 BlockIndexV2::SIZE
882 } else {
883 BlockIndex::SIZE
884 };
885
886 let index_size = header.num_blocks as usize * index_entry_size;
888 let mut index_buf = vec![0u8; index_size];
889 reader
890 .read_exact(&mut index_buf)
891 .map_err(|e| Error::algorithm("hct", format!("failed to read block index: {}", e)))?;
892
893 let block_index: Vec<BlockIndexV2> = if index_entry_size == BlockIndexV2::SIZE {
894 index_buf
895 .chunks_exact(BlockIndexV2::SIZE)
896 .map(|chunk| BlockIndexV2::from_bytes(chunk.try_into().unwrap()))
897 .collect()
898 } else {
899 index_buf
901 .chunks_exact(BlockIndex::SIZE)
902 .map(|chunk| {
903 let v1 = BlockIndex::from_bytes(chunk.try_into().unwrap());
904 BlockIndexV2::from_v1(v1)
905 })
906 .collect()
907 };
908
909 let quant_size = if quantization.is_some() {
910 QuantizationMetadata::SIZE
911 } else {
912 0
913 };
914 let data_offset = HctHeader::SIZE as u64 + quant_size as u64 + index_size as u64;
915
916 Ok(Self {
917 reader,
918 header,
919 block_index,
920 data_offset,
921 quantization,
922 })
923 }
924
925 pub fn header(&self) -> &HctHeader {
927 &self.header
928 }
929
930 pub fn quantization(&self) -> Option<&QuantizationMetadata> {
932 self.quantization.as_ref()
933 }
934
935 pub fn num_blocks(&self) -> usize {
937 self.block_index.len()
938 }
939
940 pub fn read_block_validated(&mut self, block_idx: usize) -> Result<Vec<u8>> {
942 if block_idx >= self.block_index.len() {
943 return Err(Error::corrupted(format!(
944 "block index out of range: {} >= {}",
945 block_idx,
946 self.block_index.len()
947 )));
948 }
949
950 let index = &self.block_index[block_idx];
951 let offset = self.data_offset + index.offset as u64;
952
953 self.reader.seek(SeekFrom::Start(offset)).map_err(|e| {
954 Error::algorithm(
955 "hct",
956 format!("failed to seek to block {}: {}", block_idx, e),
957 )
958 })?;
959
960 let mut buf = vec![0u8; index.compressed_size as usize];
961 self.reader.read_exact(&mut buf).map_err(|e| {
962 Error::algorithm("hct", format!("failed to read block {}: {}", block_idx, e))
963 })?;
964
965 if index.checksum != 0 {
967 let computed = xxh3_64(&buf);
968 if computed != index.checksum {
969 return Err(Error::corrupted(format!(
970 "block {} checksum mismatch: expected {:016x}, got {:016x}",
971 block_idx, index.checksum, computed
972 )));
973 }
974 }
975
976 Ok(buf)
977 }
978
979 pub fn decompress_block_validated(
981 &mut self,
982 block_idx: usize,
983 decompressor: &impl Decompressor,
984 ) -> Result<Vec<u8>> {
985 let compressed = self.read_block_validated(block_idx)?;
986
987 let is_last_block = block_idx == self.block_index.len() - 1;
989 let expected_size = if is_last_block {
990 let full_blocks = (self.block_index.len() - 1) as u64 * self.header.block_size as u64;
991 (self.header.original_size - full_blocks) as usize
992 } else {
993 self.header.block_size as usize
994 };
995
996 decompressor.decompress_with_size(&compressed, expected_size)
997 }
998
999 pub fn decompress_all_validated(
1001 &mut self,
1002 decompressor: &impl Decompressor,
1003 ) -> Result<Vec<u8>> {
1004 let mut output = Vec::with_capacity(self.header.original_size as usize);
1005
1006 for block_idx in 0..self.block_index.len() {
1007 let decompressed = self.decompress_block_validated(block_idx, decompressor)?;
1008 output.extend_from_slice(&decompressed);
1009 }
1010
1011 Ok(output)
1012 }
1013
1014 pub fn validate_checksums(&mut self) -> Result<()> {
1016 for block_idx in 0..self.block_index.len() {
1017 let _ = self.read_block_validated(block_idx)?;
1018 }
1019 Ok(())
1020 }
1021}
1022
1023#[derive(Debug, Clone)]
1025pub struct ChecksumError {
1026 pub expected: u64,
1028 pub actual: u64,
1030 pub block_index: Option<usize>,
1032}
1033
1034impl std::fmt::Display for ChecksumError {
1035 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1036 match self.block_index {
1037 Some(idx) => write!(
1038 f,
1039 "block {} checksum mismatch: expected {:016x}, got {:016x}",
1040 idx, self.expected, self.actual
1041 ),
1042 None => write!(
1043 f,
1044 "header checksum mismatch: expected {:016x}, got {:016x}",
1045 self.expected, self.actual
1046 ),
1047 }
1048 }
1049}
1050
1051#[cfg(test)]
1052mod tests {
1053 use super::*;
1054
1055 #[test]
1056 fn test_header_roundtrip() {
1057 let header = HctHeader {
1058 algorithm: CompressionAlgorithm::Zstd,
1059 dtype: DType::I4,
1060 flags: 0,
1061 original_size: 1024 * 1024,
1062 compressed_size: 256 * 1024,
1063 block_size: 64 * 1024,
1064 num_blocks: 16,
1065 shape: vec![4096, 4096],
1066 };
1067
1068 let bytes = header.to_bytes();
1069 let parsed = HctHeader::from_bytes(&bytes).unwrap();
1070
1071 assert_eq!(parsed.algorithm, header.algorithm);
1072 assert_eq!(parsed.dtype, header.dtype);
1073 assert_eq!(parsed.original_size, header.original_size);
1074 assert_eq!(parsed.compressed_size, header.compressed_size);
1075 assert_eq!(parsed.num_blocks, header.num_blocks);
1076 assert_eq!(parsed.shape, header.shape);
1077 }
1078
1079 #[test]
1080 fn test_block_index_roundtrip() {
1081 let index = BlockIndex {
1082 offset: 12345,
1083 compressed_size: 6789,
1084 };
1085
1086 let bytes = index.to_bytes();
1087 let parsed = BlockIndex::from_bytes(&bytes);
1088
1089 assert_eq!(parsed.offset, index.offset);
1090 assert_eq!(parsed.compressed_size, index.compressed_size);
1091 }
1092
1093 #[test]
1094 #[cfg(feature = "zstd")]
1095 fn test_hct_zstd_roundtrip() {
1096 use haagenti_zstd::{ZstdCompressor, ZstdDecompressor};
1097 use std::io::Cursor;
1098
1099 let original_data: Vec<u8> = (0..65536).map(|i| ((i % 256) as i8) as u8).collect();
1101
1102 let mut buffer = Vec::new();
1104 {
1105 let cursor = Cursor::new(&mut buffer);
1106 let compressor = ZstdCompressor::new();
1107
1108 let mut writer = HctWriter::new(
1109 cursor,
1110 CompressionAlgorithm::Zstd,
1111 DType::I8,
1112 vec![256, 256],
1113 );
1114 writer.compress_data(&original_data, &compressor).unwrap();
1115 writer.finish().unwrap();
1116 }
1117
1118 assert!(buffer.len() < original_data.len(), "Should compress");
1120 assert!(&buffer[0..4] == &HCT_MAGIC, "Should start with HCT magic");
1121
1122 let cursor = Cursor::new(&buffer);
1124 let mut reader = HctReader::new(cursor).unwrap();
1125 let decompressor = ZstdDecompressor::new();
1126
1127 assert_eq!(reader.num_blocks(), 4, "Should have 4 blocks");
1129
1130 let decompressed = reader.decompress_all(&decompressor).unwrap();
1131 assert_eq!(decompressed, original_data);
1132 }
1133
1134 #[test]
1135 #[cfg(feature = "lz4")]
1136 fn test_hct_lz4_roundtrip() {
1137 use haagenti_lz4::Lz4Compressor;
1138 use std::io::Cursor;
1139
1140 let mut original_data = vec![0u8; 65536];
1142 for i in (0..65536).step_by(100) {
1143 original_data[i] = (i % 256) as u8;
1144 }
1145
1146 let mut buffer = Vec::new();
1148 let cursor = Cursor::new(&mut buffer);
1149 let compressor = Lz4Compressor::new();
1150
1151 let mut writer =
1152 HctWriter::new(cursor, CompressionAlgorithm::Lz4, DType::I8, vec![256, 256]);
1153 writer.compress_data(&original_data, &compressor).unwrap();
1154 writer.finish().unwrap();
1155
1156 assert!(buffer[0..4] == HCT_MAGIC);
1157
1158 use haagenti_lz4::Lz4Decompressor;
1160
1161 let cursor = Cursor::new(&buffer);
1162 let mut reader = HctReader::new(cursor).unwrap();
1163 let decompressor = Lz4Decompressor::new();
1164
1165 let algorithm = reader.header().algorithm;
1167 let dtype = reader.header().dtype;
1168 let original_size = reader.header().original_size;
1169 let block_size = reader.header().block_size;
1170 let num_blocks = reader.num_blocks();
1171
1172 assert_eq!(algorithm, CompressionAlgorithm::Lz4);
1173 assert_eq!(dtype, DType::I8);
1174 assert_eq!(original_size, 65536);
1175
1176 for i in 0..num_blocks {
1178 let block = reader.decompress_block(i, &decompressor).unwrap();
1179 let expected_len = if i == num_blocks - 1 {
1180 (original_size as usize) % (block_size as usize)
1181 } else {
1182 block_size as usize
1183 };
1184 let expected_len = if expected_len == 0 {
1186 block_size as usize
1187 } else {
1188 expected_len
1189 };
1190 assert_eq!(block.len(), expected_len);
1191 }
1192 }
1193
1194 #[test]
1195 #[cfg(feature = "zstd")]
1196 fn test_hct_block_random_access() {
1197 use haagenti_zstd::{ZstdCompressor, ZstdDecompressor};
1198 use std::io::Cursor;
1199
1200 let block_size = 1024u32;
1202 let num_blocks = 4usize;
1203 let mut original_data = Vec::new();
1204 for block_idx in 0..num_blocks {
1205 for _ in 0..block_size {
1206 original_data.push(block_idx as u8 * 10);
1207 }
1208 }
1209
1210 let mut buffer = Vec::new();
1212 let cursor = Cursor::new(&mut buffer);
1213 let compressor = ZstdCompressor::new();
1214
1215 let mut writer = HctWriter::new(
1216 cursor,
1217 CompressionAlgorithm::Zstd,
1218 DType::I8,
1219 vec![num_blocks as u64, block_size as u64],
1220 )
1221 .with_block_size(block_size);
1222 writer.compress_data(&original_data, &compressor).unwrap();
1223 writer.finish().unwrap();
1224
1225 let cursor = Cursor::new(&buffer);
1227 let mut reader = HctReader::new(cursor).unwrap();
1228 let decompressor = ZstdDecompressor::new();
1229
1230 let block2 = reader.decompress_block(2, &decompressor).unwrap();
1232 assert!(block2.iter().all(|&b| b == 20));
1233
1234 let block0 = reader.decompress_block(0, &decompressor).unwrap();
1236 assert!(block0.iter().all(|&b| b == 0));
1237
1238 let block3 = reader.decompress_block(3, &decompressor).unwrap();
1240 assert!(block3.iter().all(|&b| b == 30));
1241
1242 let block1 = reader.decompress_block(1, &decompressor).unwrap();
1244 assert!(block1.iter().all(|&b| b == 10));
1245 }
1246
1247 #[test]
1250 fn test_quantization_metadata_roundtrip() {
1251 let quant = QuantizationMetadata {
1252 scheme: QuantizationScheme::GptqInt4,
1253 group_size: 128,
            scale_bits: 0x3C00,
            zero_point: -8,
1256 has_per_group_scales: true,
1257 };
1258
1259 let bytes = quant.to_bytes();
1260 let parsed = QuantizationMetadata::from_bytes(&bytes).unwrap();
1261
1262 assert_eq!(parsed.scheme, QuantizationScheme::GptqInt4);
1263 assert_eq!(parsed.group_size, 128);
1264 assert_eq!(parsed.scale_bits, 0x3C00);
1265 assert_eq!(parsed.zero_point, -8);
1266 assert!(parsed.has_per_group_scales);
1267 }
1268
1269 #[test]
1270 fn test_block_index_v2_roundtrip() {
1271 let index = BlockIndexV2 {
1272 offset: 12345,
1273 compressed_size: 6789,
1274 checksum: 0xDEAD_BEEF_CAFE_BABE,
1275 };
1276
1277 let bytes = index.to_bytes();
1278 let parsed = BlockIndexV2::from_bytes(&bytes);
1279
1280 assert_eq!(parsed.offset, index.offset);
1281 assert_eq!(parsed.compressed_size, index.compressed_size);
1282 assert_eq!(parsed.checksum, index.checksum);
1283 }
1284
1285 #[test]
1286 #[cfg(feature = "lz4")]
1287 fn test_hct_v2_checksum_valid() {
1288 use haagenti_lz4::{Lz4Compressor, Lz4Decompressor};
1289 use std::io::Cursor;
1290
1291 let original_data: Vec<u8> = (0..16384).map(|i| (i % 256) as u8).collect();
1293
1294 let mut buffer = Vec::new();
1296 {
1297 let cursor = Cursor::new(&mut buffer);
1298 let compressor = Lz4Compressor::new();
1299
1300 let mut writer =
1301 HctWriterV2::new(cursor, CompressionAlgorithm::Lz4, DType::I8, vec![16384]);
1302 writer.compress_data(&original_data, &compressor).unwrap();
1303 writer.finish().unwrap();
1304 }
1305
1306 let cursor = Cursor::new(&buffer);
1308 let mut reader = HctReaderV2::new(cursor).unwrap();
1309
1310 reader.validate_checksums().unwrap();
1312
1313 let cursor = Cursor::new(&buffer);
1315 let mut reader = HctReaderV2::new(cursor).unwrap();
1316 let decompressor = Lz4Decompressor::new();
1317 let decompressed = reader.decompress_all_validated(&decompressor).unwrap();
1318
1319 assert_eq!(decompressed, original_data);
1320 }
1321
1322 #[test]
1323 #[cfg(feature = "lz4")]
1324 fn test_hct_v2_checksum_detects_corruption() {
1325 use haagenti_lz4::Lz4Compressor;
1326 use std::io::Cursor;
1327
1328 let original_data: Vec<u8> = (0..16384).map(|i| (i % 256) as u8).collect();
1330
1331 let mut buffer = Vec::new();
1333 {
1334 let cursor = Cursor::new(&mut buffer);
1335 let compressor = Lz4Compressor::new();
1336
1337 let mut writer =
1338 HctWriterV2::new(cursor, CompressionAlgorithm::Lz4, DType::I8, vec![16384]);
1339 writer.compress_data(&original_data, &compressor).unwrap();
1340 writer.finish().unwrap();
1341 }
1342
        // Flip a byte in the compressed block data (the header occupies bytes
        // 0..64 and the single-block index bytes 64..80).
        let corruption_offset = 100;
        buffer[corruption_offset] ^= 0xFF;
1347
1348 let cursor = Cursor::new(&buffer);
1350 let result = HctReaderV2::new(cursor);
1351
        match result {
            // Corruption caught while parsing the header or index.
            Err(_) => {}
            // Otherwise the per-block checksum must flag the flipped byte.
            Ok(mut reader) => {
                let validate_result = reader.validate_checksums();
                assert!(validate_result.is_err(), "Should detect block corruption");
            }
        }
1363 }
1364
1365 #[test]
1366 #[cfg(feature = "lz4")]
1367 fn test_hct_v2_with_quantization_metadata() {
1368 use haagenti_lz4::{Lz4Compressor, Lz4Decompressor};
1369 use std::io::Cursor;
1370
1371 let original_data: Vec<u8> = (0..4096).map(|i| (i % 256) as u8).collect();
1372
1373 let quant = QuantizationMetadata {
1374 scheme: QuantizationScheme::GptqInt4,
1375 group_size: 128,
1376 scale_bits: 0x3C00,
1377 zero_point: 0,
1378 has_per_group_scales: false,
1379 };
1380
1381 let mut buffer = Vec::new();
1383 {
1384 let cursor = Cursor::new(&mut buffer);
1385 let compressor = Lz4Compressor::new();
1386
1387 let mut writer =
1388 HctWriterV2::new(cursor, CompressionAlgorithm::Lz4, DType::I4, vec![4096])
1389 .with_quantization(quant);
1390 writer.compress_data(&original_data, &compressor).unwrap();
1391 writer.finish().unwrap();
1392 }
1393
1394 let cursor = Cursor::new(&buffer);
1396 let mut reader = HctReaderV2::new(cursor).unwrap();
1397
1398 assert!(reader.quantization().is_some());
1399 let read_quant = reader.quantization().unwrap();
1400 assert_eq!(read_quant.scheme, QuantizationScheme::GptqInt4);
1401 assert_eq!(read_quant.group_size, 128);
1402 assert_eq!(read_quant.scale_bits, 0x3C00);
1403
1404 let decompressor = Lz4Decompressor::new();
1406 let decompressed = reader.decompress_all_validated(&decompressor).unwrap();
1407 assert_eq!(decompressed, original_data);
1408 }
1409
1410 #[test]
1411 #[cfg(feature = "lz4")]
1412 fn test_hct_v2_backward_compatible_with_v1_reader() {
1413 use haagenti_lz4::{Lz4Compressor, Lz4Decompressor};
1414 use std::io::Cursor;
1415
1416 let original_data: Vec<u8> = (0..8192).map(|i| (i % 256) as u8).collect();
1418
1419 let mut buffer = Vec::new();
1421 {
1422 let cursor = Cursor::new(&mut buffer);
1423 let compressor = Lz4Compressor::new();
1424
1425 let mut writer =
1426 HctWriter::new(cursor, CompressionAlgorithm::Lz4, DType::I8, vec![8192]);
1427 writer.compress_data(&original_data, &compressor).unwrap();
1428 writer.finish().unwrap();
1429 }
1430
1431 let cursor = Cursor::new(&buffer);
1433 let mut reader = HctReader::new(cursor).unwrap();
1434 let decompressor = Lz4Decompressor::new();
1435 let decompressed = reader.decompress_all(&decompressor).unwrap();
1436
1437 assert_eq!(decompressed, original_data);
1438 }
1439
1440 #[test]
1447 fn test_corrupted_magic_number() {
1448 let mut bytes = HctHeader {
1449 algorithm: CompressionAlgorithm::Lz4,
1450 dtype: DType::F32,
1451 flags: 0,
1452 original_size: 1024,
1453 compressed_size: 512,
1454 block_size: 1024,
1455 num_blocks: 1,
1456 shape: vec![32, 32],
1457 }
1458 .to_bytes();
1459
1460 bytes[0] = 0xFF;
1462 bytes[1] = 0xFF;
1463
1464 let result = HctHeader::from_bytes(&bytes);
1465 assert!(result.is_err(), "Should reject corrupted magic");
1466 }
1467
1468 #[test]
1469 fn test_corrupted_version() {
1470 let mut bytes = HctHeader {
1471 algorithm: CompressionAlgorithm::Lz4,
1472 dtype: DType::F32,
1473 flags: 0,
1474 original_size: 1024,
1475 compressed_size: 512,
1476 block_size: 1024,
1477 num_blocks: 1,
1478 shape: vec![32, 32],
1479 }
1480 .to_bytes();
1481
1482 bytes[4] = 0xFF;
1484
1485 let result = HctHeader::from_bytes(&bytes);
1486 assert!(result.is_err(), "Should reject unsupported version");
1487 }
1488
1489 #[test]
1490 fn test_corrupted_algorithm_field() {
1491 let mut bytes = HctHeader {
1492 algorithm: CompressionAlgorithm::Lz4,
1493 dtype: DType::F32,
1494 flags: 0,
1495 original_size: 1024,
1496 compressed_size: 512,
1497 block_size: 1024,
1498 num_blocks: 1,
1499 shape: vec![32, 32],
1500 }
1501 .to_bytes();
1502
        // Byte 8 holds the compression algorithm.
        bytes[8] = 0xFF;
1505
1506 let result = HctHeader::from_bytes(&bytes);
1507 assert!(result.is_err(), "Should reject invalid algorithm");
1508 }
1509
1510 #[test]
1511 fn test_corrupted_dtype_field() {
1512 let mut bytes = HctHeader {
1513 algorithm: CompressionAlgorithm::Lz4,
1514 dtype: DType::F32,
1515 flags: 0,
1516 original_size: 1024,
1517 compressed_size: 512,
1518 block_size: 1024,
1519 num_blocks: 1,
1520 shape: vec![32, 32],
1521 }
1522 .to_bytes();
1523
        // Byte 9 holds the dtype.
        bytes[9] = 0xFF;
1526
1527 let result = HctHeader::from_bytes(&bytes);
1528 assert!(result.is_err(), "Should reject invalid dtype");
1529 }
1530
1531 #[test]
1534 fn test_truncated_header() {
        // `HctHeader::from_bytes` takes a fixed-size `[u8; 64]`, so a shorter
        // buffer cannot even be passed to it; the assertion below documents
        // the minimum header size instead.
        let _small_buf: [u8; 16] = [0; 16];
1538
1539 assert!(
1542 HctHeader::SIZE >= 32,
1543 "Header should require at least 32 bytes"
1544 );
1545 }
1546
1547 #[test]
1550 fn test_quantization_scheme_symmetric_int8() {
1551 let quant = QuantizationMetadata {
1552 scheme: QuantizationScheme::SymmetricInt8,
1553 group_size: 64,
            scale_bits: 0x4000,
            zero_point: 0,
1556 has_per_group_scales: false,
1557 };
1558
1559 let bytes = quant.to_bytes();
1560 let parsed = QuantizationMetadata::from_bytes(&bytes).unwrap();
1561
1562 assert_eq!(parsed.scheme, QuantizationScheme::SymmetricInt8);
1563 assert_eq!(parsed.group_size, 64);
1564 assert_eq!(parsed.zero_point, 0);
1565 }
1566
1567 #[test]
1568 fn test_quantization_scheme_asymmetric_int8() {
1569 let quant = QuantizationMetadata {
1570 scheme: QuantizationScheme::AsymmetricInt8,
1571 group_size: 32,
1572 scale_bits: 0x3C00,
            zero_point: -128,
            has_per_group_scales: true,
1575 };
1576
1577 let bytes = quant.to_bytes();
1578 let parsed = QuantizationMetadata::from_bytes(&bytes).unwrap();
1579
1580 assert_eq!(parsed.scheme, QuantizationScheme::AsymmetricInt8);
1581 assert_eq!(parsed.zero_point, -128);
1582 assert!(parsed.has_per_group_scales);
1583 }
1584
1585 #[test]
1586 fn test_quantization_scheme_awq_int4() {
1587 let quant = QuantizationMetadata {
1588 scheme: QuantizationScheme::AwqInt4,
1589 group_size: 128,
            scale_bits: 0x3800,
            zero_point: 8,
1592 has_per_group_scales: true,
1593 };
1594
1595 let bytes = quant.to_bytes();
1596 let parsed = QuantizationMetadata::from_bytes(&bytes).unwrap();
1597
1598 assert_eq!(parsed.scheme, QuantizationScheme::AwqInt4);
1599 assert_eq!(parsed.group_size, 128);
1600 }
1601
1602 #[test]
1603 fn test_quantization_scheme_gptq_int4() {
1604 let quant = QuantizationMetadata {
1605 scheme: QuantizationScheme::GptqInt4,
1606 group_size: 128,
1607 scale_bits: 0x3C00,
1608 zero_point: 0,
1609 has_per_group_scales: true,
1610 };
1611
1612 let bytes = quant.to_bytes();
1613 let parsed = QuantizationMetadata::from_bytes(&bytes).unwrap();
1614
1615 assert_eq!(parsed.scheme, QuantizationScheme::GptqInt4);
1616 }
1617
1618 #[test]
1619 fn test_quantization_scheme_none() {
1620 let quant = QuantizationMetadata {
1621 scheme: QuantizationScheme::None,
1622 group_size: 0,
1623 scale_bits: 0,
1624 zero_point: 0,
1625 has_per_group_scales: false,
1626 };
1627
1628 let bytes = quant.to_bytes();
1629 let parsed = QuantizationMetadata::from_bytes(&bytes).unwrap();
1630
1631 assert_eq!(parsed.scheme, QuantizationScheme::None);
1632 }
1633
1634 #[test]
1637 fn test_header_zero_blocks() {
1638 let header = HctHeader {
1639 algorithm: CompressionAlgorithm::Lz4,
1640 dtype: DType::F32,
1641 flags: 0,
1642 original_size: 0,
1643 compressed_size: 0,
1644 block_size: 1024,
1645 num_blocks: 0,
1646 shape: vec![0],
1647 };
1648
1649 let bytes = header.to_bytes();
1650 let parsed = HctHeader::from_bytes(&bytes).unwrap();
1651
1652 assert_eq!(parsed.num_blocks, 0);
1653 assert_eq!(parsed.original_size, 0);
1654 }
1655
1656 #[test]
1657 fn test_header_single_block() {
1658 let header = HctHeader {
1659 algorithm: CompressionAlgorithm::Zstd,
1660 dtype: DType::F32,
1661 flags: 0,
1662 original_size: 512,
1663 compressed_size: 256,
1664 block_size: 1024,
1665 num_blocks: 1,
1666 shape: vec![128],
1667 };
1668
1669 let bytes = header.to_bytes();
1670 let parsed = HctHeader::from_bytes(&bytes).unwrap();
1671
1672 assert_eq!(parsed.num_blocks, 1);
1673 assert!(parsed.original_size < parsed.block_size as u64);
1675 }
1676
1677 #[test]
1678 fn test_header_exact_block_multiple() {
1679 let header = HctHeader {
1680 algorithm: CompressionAlgorithm::Lz4,
1681 dtype: DType::I8,
1682 flags: 0,
1683 original_size: 4096,
1684 compressed_size: 2048,
1685 block_size: 1024,
1686 num_blocks: 4,
1687 shape: vec![4096],
1688 };
1689
1690 let bytes = header.to_bytes();
1691 let parsed = HctHeader::from_bytes(&bytes).unwrap();
1692
1693 assert_eq!(parsed.num_blocks, 4);
1695 assert_eq!(parsed.original_size, 4 * parsed.block_size as u64);
1696 }
1697
1698 #[test]
1699 fn test_header_partial_final_block() {
1700 let header = HctHeader {
1701 algorithm: CompressionAlgorithm::Lz4,
1702 dtype: DType::I8,
1703 flags: 0,
1704 original_size: 4500,
1705 compressed_size: 2250,
1706 block_size: 1024,
1707 num_blocks: 5,
1708 shape: vec![4500],
1709 };
1710
1711 let bytes = header.to_bytes();
1712 let parsed = HctHeader::from_bytes(&bytes).unwrap();
1713
1714 assert_eq!(parsed.num_blocks, 5);
1717 let last_block_size = parsed.original_size as u32 % parsed.block_size;
1718 assert_eq!(last_block_size, 404);
1719 }
1720
1721 #[test]
1724 fn test_header_1d_shape() {
1725 let header = HctHeader {
1726 algorithm: CompressionAlgorithm::Lz4,
1727 dtype: DType::F32,
1728 flags: 0,
1729 original_size: 4096,
1730 compressed_size: 2048,
1731 block_size: 1024,
1732 num_blocks: 4,
1733 shape: vec![1024],
1734 };
1735
1736 let bytes = header.to_bytes();
1737 let parsed = HctHeader::from_bytes(&bytes).unwrap();
1738
1739 assert_eq!(parsed.shape.len(), 1);
1740 assert_eq!(parsed.shape[0], 1024);
1741 }
1742
1743 #[test]
1744 fn test_header_2d_shape() {
1745 let header = HctHeader {
1746 algorithm: CompressionAlgorithm::Lz4,
1747 dtype: DType::F32,
1748 flags: 0,
1749 original_size: 4096,
1750 compressed_size: 2048,
1751 block_size: 1024,
1752 num_blocks: 4,
1753 shape: vec![32, 32],
1754 };
1755
1756 let bytes = header.to_bytes();
1757 let parsed = HctHeader::from_bytes(&bytes).unwrap();
1758
1759 assert_eq!(parsed.shape.len(), 2);
1760 assert_eq!(parsed.shape, vec![32, 32]);
1761 }
1762
1763 #[test]
1764 fn test_header_3d_shape() {
1765 let header = HctHeader {
1766 algorithm: CompressionAlgorithm::Lz4,
1767 dtype: DType::F32,
1768 flags: 0,
1769 original_size: 4096,
1770 compressed_size: 2048,
1771 block_size: 1024,
1772 num_blocks: 4,
1773 shape: vec![4, 16, 64],
1774 };
1775
1776 let bytes = header.to_bytes();
1777 let parsed = HctHeader::from_bytes(&bytes).unwrap();
1778
1779 assert_eq!(parsed.shape.len(), 3);
1780 assert_eq!(parsed.shape, vec![4, 16, 64]);
1781 }
1782
1783 #[test]
1784 fn test_header_max_3_dimensions() {
1785 assert_eq!(HctHeader::SIZE, 64);
1793
1794 let header = HctHeader {
1800 algorithm: CompressionAlgorithm::Zstd,
1801 dtype: DType::BF16,
1802 flags: 0,
1803 original_size: u64::MAX / 2,
1804 compressed_size: u64::MAX / 4,
1805 block_size: u32::MAX,
1806 num_blocks: u32::MAX / 2,
            shape: vec![8192, 8192, 128],
        };
1809
1810 let bytes = header.to_bytes();
1811 let parsed = HctHeader::from_bytes(&bytes).unwrap();
1812
1813 assert_eq!(parsed.shape.len(), 3);
1814 assert_eq!(parsed.shape[0], 8192);
1815 assert_eq!(parsed.shape[1], 8192);
1816 assert_eq!(parsed.shape[2], 128);
1817 }
1818
1819 #[test]
1822 fn test_all_dtypes_roundtrip() {
1823 let dtypes = [DType::F32, DType::F16, DType::BF16, DType::I8, DType::I4];
1824
1825 for dtype in dtypes {
1826 let header = HctHeader {
1827 algorithm: CompressionAlgorithm::Lz4,
1828 dtype,
1829 flags: 0,
1830 original_size: 1024,
1831 compressed_size: 512,
1832 block_size: 1024,
1833 num_blocks: 1,
1834 shape: vec![256],
1835 };
1836
1837 let bytes = header.to_bytes();
1838 let parsed = HctHeader::from_bytes(&bytes).unwrap();
1839
1840 assert_eq!(parsed.dtype, dtype, "DType {:?} should roundtrip", dtype);
1841 }
1842 }
1843
1844 #[test]
1847 fn test_all_algorithms_roundtrip() {
1848 let algorithms = [CompressionAlgorithm::Lz4, CompressionAlgorithm::Zstd];
1849
1850 for algorithm in algorithms {
1851 let header = HctHeader {
1852 algorithm,
1853 dtype: DType::F32,
1854 flags: 0,
1855 original_size: 1024,
1856 compressed_size: 512,
1857 block_size: 1024,
1858 num_blocks: 1,
1859 shape: vec![256],
1860 };
1861
1862 let bytes = header.to_bytes();
1863 let parsed = HctHeader::from_bytes(&bytes).unwrap();
1864
1865 assert_eq!(
1866 parsed.algorithm, algorithm,
1867 "Algorithm {:?} should roundtrip",
1868 algorithm
1869 );
1870 }
1871 }
1872
1873 #[test]
1876 fn test_flags_preserved() {
1877 let flags_to_test = [0x0000, 0x0001, 0x0002, 0xFFFF];
1878
1879 for flags in flags_to_test {
1880 let header = HctHeader {
1881 algorithm: CompressionAlgorithm::Lz4,
1882 dtype: DType::F32,
1883 flags,
1884 original_size: 1024,
1885 compressed_size: 512,
1886 block_size: 1024,
1887 num_blocks: 1,
1888 shape: vec![256],
1889 };
1890
1891 let bytes = header.to_bytes();
1892 let parsed = HctHeader::from_bytes(&bytes).unwrap();
1893
1894 assert_eq!(
1895 parsed.flags, flags,
1896 "Flags {:04X} should be preserved",
1897 flags
1898 );
1899 }
1900 }
1901
1902 #[test]
1905 fn test_large_original_size() {
1906 let header = HctHeader {
1907 algorithm: CompressionAlgorithm::Lz4,
1908 dtype: DType::F32,
1909 flags: 0,
1910 original_size: u64::MAX - 1,
1911 compressed_size: u64::MAX / 2,
1912 block_size: u32::MAX,
1913 num_blocks: u32::MAX,
1914 shape: vec![u64::MAX],
1915 };
1916
1917 let bytes = header.to_bytes();
1918 let parsed = HctHeader::from_bytes(&bytes).unwrap();
1919
1920 assert_eq!(parsed.original_size, u64::MAX - 1);
1921 assert_eq!(parsed.compressed_size, u64::MAX / 2);
1922 assert_eq!(parsed.block_size, u32::MAX);
1923 assert_eq!(parsed.num_blocks, u32::MAX);
1924 }
1925
1926 #[test]
1929 fn test_block_index_large_values() {
1930 let index = BlockIndex {
1931 offset: u32::MAX - 1,
1932 compressed_size: u32::MAX - 1,
1933 };
1934
1935 let bytes = index.to_bytes();
1936 let parsed = BlockIndex::from_bytes(&bytes);
1937
1938 assert_eq!(parsed.offset, u32::MAX - 1);
1939 assert_eq!(parsed.compressed_size, u32::MAX - 1);
1940 }
1941
1942 #[test]
1943 fn test_block_index_v2_checksum_uniqueness() {
1944 let index1 = BlockIndexV2 {
1945 offset: 100,
1946 compressed_size: 50,
1947 checksum: 0xABCD_EF01_2345_6789,
1948 };
1949
1950 let index2 = BlockIndexV2 {
1951 offset: 100,
1952 compressed_size: 50,
1953 checksum: 0x9876_5432_10FE_DCBA,
1954 };
1955
1956 let bytes1 = index1.to_bytes();
1957 let bytes2 = index2.to_bytes();
1958
1959 assert_ne!(bytes1, bytes2);
1961
1962 let parsed1 = BlockIndexV2::from_bytes(&bytes1);
1964 let parsed2 = BlockIndexV2::from_bytes(&bytes2);
1965
1966 assert_eq!(parsed1.checksum, index1.checksum);
1967 assert_eq!(parsed2.checksum, index2.checksum);
1968 }
1969
1970 #[test]
1973 fn test_reader_invalid_block_index_bounds() {
        let data = vec![0u8; 4096]; // four 1024-byte blocks
        let mut output = Vec::new();
1977
1978 {
1979 let cursor = std::io::Cursor::new(&mut output);
1980 let mut writer =
1981 HctWriter::new(cursor, CompressionAlgorithm::Lz4, DType::F32, vec![64, 16])
1982 .with_block_size(1024);
1983
1984 let codec = haagenti_lz4::Lz4Codec::new();
1985 writer.compress_data(&data, &codec).unwrap();
1986 writer.finish().unwrap();
1987 }
1988
1989 let cursor = std::io::Cursor::new(&output);
1991 let mut reader = HctReader::new(cursor).unwrap();
1992
1993 let result = reader.read_block(999);
1995 assert!(result.is_err(), "Should error on invalid block index");
1996 let err_msg = format!("{}", result.unwrap_err());
1997 assert!(
1998 err_msg.contains("block index out of range") || err_msg.contains("corrupted"),
1999 "Error should mention invalid block: {}",
2000 err_msg
2001 );
2002 }
2003
2004 #[test]
2005 fn test_reader_truncated_header() {
2006 let short_data = vec![0u8; HctHeader::SIZE - 10];
2008 let cursor = std::io::Cursor::new(short_data);
2009
2010 let result = HctReader::new(cursor);
2011 assert!(result.is_err(), "Should error on truncated header");
2012 }
2013
2014 #[test]
2015 fn test_reader_truncated_block_index() {
2016 let mut data = [0u8; HctHeader::SIZE];
2018
2019 data[0..4].copy_from_slice(&HCT_MAGIC);
2021 data[4..8].copy_from_slice(&HCT_VERSION.to_le_bytes());
2023 data[8] = 0;
2025 data[9] = 0;
2027 data[32..36].copy_from_slice(&10u32.to_le_bytes());
2029 data[36] = 1;
2031
2032 let cursor = std::io::Cursor::new(data.to_vec());
2034
2035 let result = HctReader::new(cursor);
2036 assert!(
2037 result.is_err(),
2038 "Should error when block index is truncated"
2039 );
2040 }
2041
2042 #[test]
2043 fn test_v2_checksum_validation_detects_bitflip() {
        let data = vec![42u8; 2048]; // two 1024-byte blocks of a constant byte
        let mut output = Vec::new();
2047
2048 {
2049 let cursor = std::io::Cursor::new(&mut output);
2050 let mut writer =
2051 HctWriterV2::new(cursor, CompressionAlgorithm::Zstd, DType::F32, vec![32, 16])
2052 .with_block_size(1024);
2053
2054 let codec = haagenti_zstd::ZstdCodec::new();
2055 writer.compress_data(&data, &codec).unwrap();
2056 writer.finish().unwrap();
2057 }
2058
2059 let header_size = HctHeader::SIZE;
2061 if output.len() > header_size + 32 {
            // Aim past the 64-byte header and the 32-byte block index; with
            // highly compressible input the file is small, so guard against
            // running off the end of the buffer.
            let corrupt_pos = header_size + 50;
            if corrupt_pos < output.len() {
                output[corrupt_pos] ^= 0xFF;
            }
2068 }
2069
2070 let cursor = std::io::Cursor::new(&output);
2072 let reader_result = HctReaderV2::new(cursor);
2073
        if let Ok(mut reader) = reader_result {
            // Depending on where the flipped byte landed it may hit block 0 or
            // block 1, so just exercise the validated read path here rather
            // than asserting a specific outcome.
            let block_result = reader.read_block_validated(0);
            let _ = block_result;
2084 }
2085 }
2086
2087 #[test]
2088 fn test_empty_data_compression() {
2089 let data: Vec<u8> = vec![];
2091 let mut output = Vec::new();
2092
2093 {
2094 let cursor = std::io::Cursor::new(&mut output);
2095 let mut writer = HctWriter::new(
2096 cursor,
2097 CompressionAlgorithm::Lz4,
2098 DType::F32,
                vec![0],
            )
2101 .with_block_size(1024);
2102
2103 let codec = haagenti_lz4::Lz4Codec::new();
2104 writer.compress_data(&data, &codec).unwrap();
2105 writer.finish().unwrap();
2106 }
2107
2108 let cursor = std::io::Cursor::new(&output);
2110 let mut reader = HctReader::new(cursor).unwrap();
2111
2112 assert_eq!(reader.header().num_blocks, 0);
2113 assert_eq!(reader.header().original_size, 0);
2114 }
2115
2116 #[test]
2117 fn test_reader_with_completely_invalid_data() {
2118 let garbage = vec![0xDE, 0xAD, 0xBE, 0xEF, 0x12, 0x34, 0x56, 0x78];
2120 let cursor = std::io::Cursor::new(garbage);
2121
2122 let result = HctReader::new(cursor);
2123 assert!(result.is_err(), "Should reject invalid data");
2124 }
2125
2126 #[test]
2127 fn test_writer_multiple_compressions() {
2128 let mut output = Vec::new();
2130 let data = vec![1u8; 100];
2131
2132 let cursor = std::io::Cursor::new(&mut output);
2133 let mut writer = HctWriter::new(cursor, CompressionAlgorithm::Lz4, DType::F32, vec![100])
2134 .with_block_size(64);
2135
2136 let codec = haagenti_lz4::Lz4Codec::new();
2137 writer.compress_data(&data, &codec).unwrap();
2138 writer.finish().unwrap();
2140
2141 let cursor = std::io::Cursor::new(&output);
2143 let reader = HctReader::new(cursor);
2144 assert!(
2145 reader.is_ok(),
2146 "Should be able to read back compressed data"
2147 );
2148 }
2149
2150 #[test]
2151 fn test_block_boundary_at_exact_size() {
2152 let block_size = 256;
2154 let data = vec![0xABu8; block_size];
2155 let mut output = Vec::new();
2156
2157 {
2158 let cursor = std::io::Cursor::new(&mut output);
2159 let mut writer = HctWriter::new(
2160 cursor,
2161 CompressionAlgorithm::Lz4,
2162 DType::F32,
2163 vec![block_size as u64],
2164 )
2165 .with_block_size(block_size as u32);
2166
2167 let codec = haagenti_lz4::Lz4Codec::new();
2168 writer.compress_data(&data, &codec).unwrap();
2169 writer.finish().unwrap();
2170 }
2171
2172 let cursor = std::io::Cursor::new(&output);
2173 let reader = HctReader::new(cursor).unwrap();
2174
2175 assert_eq!(reader.header().num_blocks, 1);
2177 }
2178
2179 #[test]
2180 fn test_block_boundary_at_size_plus_one() {
2181 let block_size = 256;
2183 let data = vec![0xCDu8; block_size + 1];
2184 let mut output = Vec::new();
2185
2186 {
2187 let cursor = std::io::Cursor::new(&mut output);
2188 let mut writer = HctWriter::new(
2189 cursor,
2190 CompressionAlgorithm::Lz4,
2191 DType::F32,
2192 vec![(block_size + 1) as u64],
2193 )
2194 .with_block_size(block_size as u32);
2195
2196 let codec = haagenti_lz4::Lz4Codec::new();
2197 writer.compress_data(&data, &codec).unwrap();
2198 writer.finish().unwrap();
2199 }
2200
2201 let cursor = std::io::Cursor::new(&output);
2202 let reader = HctReader::new(cursor).unwrap();
2203
2204 assert_eq!(reader.header().num_blocks, 2);
2206 }
2207
2208 #[test]
2209 fn test_decompression_with_wrong_algorithm() {
2210 let data = vec![0x12u8; 512];
2212 let mut output = Vec::new();
2213
2214 {
2215 let cursor = std::io::Cursor::new(&mut output);
2216 let mut writer =
2217 HctWriter::new(cursor, CompressionAlgorithm::Lz4, DType::F32, vec![512])
2218 .with_block_size(256);
2219
2220 let codec = haagenti_lz4::Lz4Codec::new();
2221 writer.compress_data(&data, &codec).unwrap();
2222 writer.finish().unwrap();
2223 }
2224
2225 let cursor = std::io::Cursor::new(&output);
2227 let mut reader = HctReader::new(cursor).unwrap();
2228
2229 assert_eq!(reader.header().algorithm, CompressionAlgorithm::Lz4);
2231
2232 let lz4 = haagenti_lz4::Lz4Codec::new();
2234 let result = reader.decompress_all(&lz4);
2235 assert!(
2236 result.is_ok(),
2237 "Decompression with correct algorithm should work"
2238 );
2239 assert_eq!(result.unwrap(), data);
2240 }
2241}