//! Compressed Tensor Format (.hct) for LLM weight storage.
//!
//! The Haagenti Compressed Tensor format stores quantized model weights
//! with block-level compression for efficient random access and parallel
//! decompression.
//!
//! ## Format Overview
//!
//! ```text
//! ┌────────────────────────────────────────────────────────────┐
//! │ Header (64 bytes)                                          │
//! │  - Magic: "HCTN" (4 bytes)                                 │
//! │  - Version: u32                                            │
//! │  - Algorithm: u8 (0=LZ4, 1=Zstd)                           │
//! │  - Dtype: u8 (0=F32, 1=F16, 2=BF16, 3=I8, 4=I4)            │
//! │  - Flags: u16                                              │
//! │  - Original size: u64                                      │
//! │  - Compressed size: u64                                    │
//! │  - Block size: u32                                         │
//! │  - Num blocks: u32                                         │
//! │  - Shape rank: u8                                          │
//! │  - Shape dims: [u64; 4]                                    │
//! │  - Reserved: padding to 64 bytes                           │
//! ├────────────────────────────────────────────────────────────┤
//! │ Block Index (num_blocks * 8 bytes)                         │
//! │  - For each block:                                         │
//! │    - Offset from data start: u32                           │
//! │    - Compressed size: u32                                  │
//! ├────────────────────────────────────────────────────────────┤
//! │ Compressed Data                                            │
//! │  - Block 0: [compressed bytes]                             │
//! │  - Block 1: [compressed bytes]                             │
//! │  - ...                                                     │
//! └────────────────────────────────────────────────────────────┘
//! ```
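//!
//! ## Example
//!
//! A minimal write/read sketch (illustrative only; it assumes the optional
//! `lz4` feature and that these types are reachable at the crate root):
//!
//! ```ignore
//! use std::io::Cursor;
//! use haagenti_lz4::{Lz4Compressor, Lz4Decompressor};
//!
//! let data = vec![0u8; 32 * 1024];
//!
//! // Write an in-memory .hct stream.
//! let mut buffer = Vec::new();
//! let mut writer = HctWriter::new(
//!     Cursor::new(&mut buffer),
//!     CompressionAlgorithm::Lz4,
//!     DType::I8,
//!     vec![256, 128],
//! );
//! writer.compress_data(&data, &Lz4Compressor::new())?;
//! writer.finish()?;
//!
//! // Read it back, block by block or all at once.
//! let mut reader = HctReader::new(Cursor::new(&buffer))?;
//! let restored = reader.decompress_all(&Lz4Decompressor::new())?;
//! assert_eq!(restored, data);
//! ```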
36
37use std::fs::File;
38use std::io::{Read, Seek, SeekFrom, Write};
39use std::path::Path;
40
41use haagenti_core::{Compressor, Decompressor, Error, Result};
42use xxhash_rust::xxh3::xxh3_64;
43
44/// Magic bytes for the HCT format.
45pub const HCT_MAGIC: [u8; 4] = *b"HCTN";
46
47/// Format version 1 (original).
48pub const HCT_VERSION: u32 = 1;
49
50/// Format version 2 (with checksums and quantization metadata).
51pub const HCT_VERSION_V2: u32 = 2;
52
53// ==================== HCT v2 Flags ====================
54
55/// Flag: Header checksum present (XXH3-64).
56pub const FLAG_HEADER_CHECKSUM: u16 = 0x0001;
57
58/// Flag: Per-block checksums present (XXH3-64 for each block).
59pub const FLAG_BLOCK_CHECKSUMS: u16 = 0x0002;
60
61/// Flag: Quantization metadata present.
62pub const FLAG_QUANTIZATION: u16 = 0x0004;
63
64/// Flag: Tensor name embedded in extended header.
65pub const FLAG_TENSOR_NAME: u16 = 0x0008;
66
67/// Flag: Holographic encoded data (HoloTensor format).
68/// When set, the HCT file contains holographic fragments instead of raw blocks.
69/// The fragment data follows the HoloTensorHeader structure.
70pub const FLAG_HOLOGRAPHIC: u16 = 0x0010;
71
72/// Default block size (16 KB uncompressed).
/// Note: 16 KB is chosen for compatibility with haagenti-zstd, which has issues at larger block sizes.
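/// For example, a 4096 × 4096 F16 tensor (32 MiB) splits into 2048 such blocks.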
74pub const DEFAULT_BLOCK_SIZE: u32 = 16 * 1024;
75
76/// Compression algorithm identifier.
77#[derive(Debug, Clone, Copy, PartialEq, Eq)]
78#[repr(u8)]
79pub enum CompressionAlgorithm {
80    Lz4 = 0,
81    Zstd = 1,
82}
83
84impl TryFrom<u8> for CompressionAlgorithm {
85    type Error = Error;
86
87    fn try_from(value: u8) -> Result<Self> {
88        match value {
89            0 => Ok(CompressionAlgorithm::Lz4),
90            1 => Ok(CompressionAlgorithm::Zstd),
91            _ => Err(Error::corrupted(format!("unknown algorithm: {}", value))),
92        }
93    }
94}
95
96/// Data type identifier.
97#[derive(Debug, Clone, Copy, PartialEq, Eq)]
98#[repr(u8)]
99pub enum DType {
100    F32 = 0,
101    F16 = 1,
102    BF16 = 2,
103    I8 = 3,
104    I4 = 4,
105}
106
107impl DType {
108    /// Returns the size in bits.
109    pub fn bits(&self) -> usize {
110        match self {
111            DType::F32 => 32,
112            DType::F16 | DType::BF16 => 16,
113            DType::I8 => 8,
114            DType::I4 => 4,
115        }
116    }
117
118    /// Returns the size in bytes (rounded up for sub-byte types).
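    ///
    /// A small sketch of the rounding behaviour (`I4` is 4 bits, which rounds
    /// up to a single byte):
    ///
    /// ```ignore
    /// assert_eq!(DType::I4.bytes(), 1);
    /// assert_eq!(DType::BF16.bytes(), 2);
    /// assert_eq!(DType::F32.bytes(), 4);
    /// ```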
119    pub fn bytes(&self) -> usize {
120        self.bits().div_ceil(8)
121    }
122}
123
124impl TryFrom<u8> for DType {
125    type Error = Error;
126
127    fn try_from(value: u8) -> Result<Self> {
128        match value {
129            0 => Ok(DType::F32),
130            1 => Ok(DType::F16),
131            2 => Ok(DType::BF16),
132            3 => Ok(DType::I8),
133            4 => Ok(DType::I4),
134            _ => Err(Error::corrupted(format!("unknown dtype: {}", value))),
135        }
136    }
137}
138
139// ==================== Quantization Metadata (v2) ====================
140
141/// Quantization scheme identifier.
142#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
143#[repr(u8)]
144pub enum QuantizationScheme {
145    /// No quantization (full precision).
146    #[default]
147    None = 0,
148    /// GPTQ-style INT4 quantization.
149    GptqInt4 = 1,
150    /// AWQ-style INT4 quantization.
151    AwqInt4 = 2,
152    /// Symmetric INT8 quantization.
153    SymmetricInt8 = 3,
154    /// Asymmetric INT8 quantization.
155    AsymmetricInt8 = 4,
156}
157
158impl TryFrom<u8> for QuantizationScheme {
159    type Error = Error;
160
161    fn try_from(value: u8) -> Result<Self> {
162        match value {
163            0 => Ok(QuantizationScheme::None),
164            1 => Ok(QuantizationScheme::GptqInt4),
165            2 => Ok(QuantizationScheme::AwqInt4),
166            3 => Ok(QuantizationScheme::SymmetricInt8),
167            4 => Ok(QuantizationScheme::AsymmetricInt8),
168            _ => Err(Error::corrupted(format!(
169                "unknown quantization scheme: {}",
170                value
171            ))),
172        }
173    }
174}
175
176/// Quantization metadata for HCT v2.
177///
178/// Contains information needed to dequantize INT4/INT8 weights.
179#[derive(Debug, Clone, Default, PartialEq)]
180pub struct QuantizationMetadata {
181    /// Quantization scheme used.
182    pub scheme: QuantizationScheme,
183    /// Group size for group-wise quantization (0 = per-tensor).
184    pub group_size: u32,
185    /// Global scale factor (f16 stored as u16 bits).
186    pub scale_bits: u16,
187    /// Global zero point (for asymmetric quantization).
188    pub zero_point: i8,
189    /// Whether per-group scales are stored after compressed data.
190    pub has_per_group_scales: bool,
191}
192
193impl QuantizationMetadata {
194    /// Size of quantization metadata in bytes.
195    pub const SIZE: usize = 8;
196
197    /// Serialize to bytes.
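    ///
    /// Layout of the 8 bytes: `scheme` (1), `has_per_group_scales` (1),
    /// `scale_bits` as little-endian u16 (2), `zero_point` (1), and
    /// `group_size` packed into the low 3 bytes of a little-endian u32 (3).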
198    pub fn to_bytes(&self) -> [u8; Self::SIZE] {
199        let mut buf = [0u8; Self::SIZE];
200        buf[0] = self.scheme as u8;
201        buf[1] = if self.has_per_group_scales { 1 } else { 0 };
202        buf[2..4].copy_from_slice(&self.scale_bits.to_le_bytes());
203        buf[4] = self.zero_point as u8;
204        buf[5..8].copy_from_slice(&self.group_size.to_le_bytes()[..3]);
205        buf
206    }
207
208    /// Parse from bytes.
209    pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Result<Self> {
210        let scheme = QuantizationScheme::try_from(buf[0])?;
211        let has_per_group_scales = buf[1] != 0;
212        let scale_bits = u16::from_le_bytes([buf[2], buf[3]]);
213        let zero_point = buf[4] as i8;
214        let mut group_size_bytes = [0u8; 4];
215        group_size_bytes[..3].copy_from_slice(&buf[5..8]);
216        let group_size = u32::from_le_bytes(group_size_bytes);
217
218        Ok(Self {
219            scheme,
220            group_size,
221            scale_bits,
222            zero_point,
223            has_per_group_scales,
224        })
225    }
226}
227
228// ==================== Block Index with Checksum (v2) ====================
229
230/// Block index entry with optional checksum for v2.
231#[derive(Debug, Clone, Copy)]
232pub struct BlockIndexV2 {
233    /// Offset from the start of compressed data.
234    pub offset: u32,
235    /// Compressed size of this block.
236    pub compressed_size: u32,
237    /// XXH3-64 checksum of compressed data (0 if not computed).
238    pub checksum: u64,
239}
240
241impl BlockIndexV2 {
242    /// Size of a v2 block index entry in bytes.
243    pub const SIZE: usize = 16;
244
245    /// Serialize to bytes.
246    pub fn to_bytes(&self) -> [u8; Self::SIZE] {
247        let mut buf = [0u8; Self::SIZE];
248        buf[0..4].copy_from_slice(&self.offset.to_le_bytes());
249        buf[4..8].copy_from_slice(&self.compressed_size.to_le_bytes());
250        buf[8..16].copy_from_slice(&self.checksum.to_le_bytes());
251        buf
252    }
253
254    /// Parse from bytes.
255    pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Self {
256        Self {
257            offset: u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]),
258            compressed_size: u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]),
259            checksum: u64::from_le_bytes(buf[8..16].try_into().unwrap()),
260        }
261    }
262
263    /// Create from v1 block index (no checksum).
264    pub fn from_v1(v1: BlockIndex) -> Self {
265        Self {
266            offset: v1.offset,
267            compressed_size: v1.compressed_size,
268            checksum: 0,
269        }
270    }
271}
272
273/// Header for the compressed tensor format.
274#[derive(Debug, Clone)]
275pub struct HctHeader {
276    /// Compression algorithm used.
277    pub algorithm: CompressionAlgorithm,
278    /// Data type of the tensor.
279    pub dtype: DType,
    /// Format flags (see the `FLAG_*` constants; zero in plain v1 files).
281    pub flags: u16,
282    /// Original uncompressed size in bytes.
283    pub original_size: u64,
284    /// Total compressed size in bytes (excluding header and index).
285    pub compressed_size: u64,
286    /// Block size for compression (uncompressed).
287    pub block_size: u32,
288    /// Number of compressed blocks.
289    pub num_blocks: u32,
290    /// Tensor shape.
291    pub shape: Vec<u64>,
292}
293
294impl HctHeader {
295    /// Header size in bytes.
296    pub const SIZE: usize = 64;
297
298    /// Serialize header to bytes.
299    pub fn to_bytes(&self) -> [u8; Self::SIZE] {
300        let mut buf = [0u8; Self::SIZE];
301
302        // Magic
303        buf[0..4].copy_from_slice(&HCT_MAGIC);
304
305        // Version
306        buf[4..8].copy_from_slice(&HCT_VERSION.to_le_bytes());
307
308        // Algorithm and dtype
309        buf[8] = self.algorithm as u8;
310        buf[9] = self.dtype as u8;
311
312        // Flags
313        buf[10..12].copy_from_slice(&self.flags.to_le_bytes());
314
315        // Sizes
316        buf[12..20].copy_from_slice(&self.original_size.to_le_bytes());
317        buf[20..28].copy_from_slice(&self.compressed_size.to_le_bytes());
318        buf[28..32].copy_from_slice(&self.block_size.to_le_bytes());
319        buf[32..36].copy_from_slice(&self.num_blocks.to_le_bytes());
320
321        // Shape
322        buf[36] = self.shape.len() as u8;
323        for (i, &dim) in self.shape.iter().take(4).enumerate() {
324            let offset = 37 + i * 8;
325            buf[offset..offset + 8].copy_from_slice(&dim.to_le_bytes());
326        }
327
328        buf
329    }
330
331    /// Parse header from bytes.
332    pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Result<Self> {
333        // Validate magic
334        if buf[0..4] != HCT_MAGIC {
335            return Err(Error::corrupted("invalid HCT magic"));
336        }
337
338        // Validate version (accept v1 or v2)
339        let version = u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]);
340        if version > HCT_VERSION_V2 {
341            return Err(Error::corrupted(format!(
342                "unsupported HCT version: {} (max: {})",
343                version, HCT_VERSION_V2
344            )));
345        }
346
347        let algorithm = CompressionAlgorithm::try_from(buf[8])?;
348        let dtype = DType::try_from(buf[9])?;
349        let flags = u16::from_le_bytes([buf[10], buf[11]]);
350
351        let original_size = u64::from_le_bytes(buf[12..20].try_into().unwrap());
352        let compressed_size = u64::from_le_bytes(buf[20..28].try_into().unwrap());
353        let block_size = u32::from_le_bytes(buf[28..32].try_into().unwrap());
354        let num_blocks = u32::from_le_bytes(buf[32..36].try_into().unwrap());
355
356        let rank = buf[36] as usize;
357        let mut shape = Vec::with_capacity(rank);
358        for i in 0..rank.min(4) {
359            let offset = 37 + i * 8;
360            let dim = u64::from_le_bytes(buf[offset..offset + 8].try_into().unwrap());
361            shape.push(dim);
362        }
363
364        Ok(Self {
365            algorithm,
366            dtype,
367            flags,
368            original_size,
369            compressed_size,
370            block_size,
371            num_blocks,
372            shape,
373        })
374    }
375}
376
377/// Block index entry.
378#[derive(Debug, Clone, Copy)]
379pub struct BlockIndex {
380    /// Offset from the start of compressed data.
381    pub offset: u32,
382    /// Compressed size of this block.
383    pub compressed_size: u32,
384}
385
386impl BlockIndex {
387    /// Size of a block index entry in bytes.
388    pub const SIZE: usize = 8;
389
390    /// Serialize to bytes.
391    pub fn to_bytes(&self) -> [u8; Self::SIZE] {
392        let mut buf = [0u8; Self::SIZE];
393        buf[0..4].copy_from_slice(&self.offset.to_le_bytes());
394        buf[4..8].copy_from_slice(&self.compressed_size.to_le_bytes());
395        buf
396    }
397
398    /// Parse from bytes.
399    pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Self {
400        Self {
401            offset: u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]),
402            compressed_size: u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]),
403        }
404    }
405}
406
407/// Reader for compressed tensor files.
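///
/// A random-access sketch (illustrative; assumes the optional `zstd` feature
/// and an `hct_bytes` buffer holding a complete .hct file):
///
/// ```ignore
/// use std::io::Cursor;
/// use haagenti_zstd::ZstdDecompressor;
///
/// let mut reader = HctReader::new(Cursor::new(&hct_bytes))?;
/// println!("shape = {:?}", reader.header().shape);
///
/// // Blocks can be decompressed independently and in any order.
/// let last = reader.num_blocks() - 1;
/// let tail = reader.decompress_block(last, &ZstdDecompressor::new())?;
/// ```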
408pub struct HctReader<R: Read + Seek> {
409    reader: R,
410    header: HctHeader,
411    block_index: Vec<BlockIndex>,
412    data_offset: u64,
413}
414
415impl<R: Read + Seek> HctReader<R> {
416    /// Open an HCT file for reading.
417    pub fn new(mut reader: R) -> Result<Self> {
418        // Read header
419        let mut header_buf = [0u8; HctHeader::SIZE];
420        reader
421            .read_exact(&mut header_buf)
422            .map_err(|e| Error::algorithm("hct", format!("failed to read header: {}", e)))?;
423        let header = HctHeader::from_bytes(&header_buf)?;
424
425        // Read block index
426        let index_size = header.num_blocks as usize * BlockIndex::SIZE;
427        let mut index_buf = vec![0u8; index_size];
428        reader
429            .read_exact(&mut index_buf)
430            .map_err(|e| Error::algorithm("hct", format!("failed to read block index: {}", e)))?;
431
432        let block_index: Vec<BlockIndex> = index_buf
433            .chunks_exact(BlockIndex::SIZE)
434            .map(|chunk| BlockIndex::from_bytes(chunk.try_into().unwrap()))
435            .collect();
436
437        let data_offset = HctHeader::SIZE as u64 + index_size as u64;
438
439        Ok(Self {
440            reader,
441            header,
442            block_index,
443            data_offset,
444        })
445    }
446
447    /// Get the header.
448    pub fn header(&self) -> &HctHeader {
449        &self.header
450    }
451
452    /// Get the number of blocks.
453    pub fn num_blocks(&self) -> usize {
454        self.block_index.len()
455    }
456
457    /// Read a single compressed block.
458    pub fn read_block(&mut self, block_idx: usize) -> Result<Vec<u8>> {
459        if block_idx >= self.block_index.len() {
460            return Err(Error::corrupted(format!(
461                "block index out of range: {} >= {}",
462                block_idx,
463                self.block_index.len()
464            )));
465        }
466
467        let index = &self.block_index[block_idx];
468        let offset = self.data_offset + index.offset as u64;
469
470        self.reader.seek(SeekFrom::Start(offset)).map_err(|e| {
471            Error::algorithm(
472                "hct",
473                format!("failed to seek to block {}: {}", block_idx, e),
474            )
475        })?;
476
477        let mut buf = vec![0u8; index.compressed_size as usize];
478        self.reader.read_exact(&mut buf).map_err(|e| {
479            Error::algorithm("hct", format!("failed to read block {}: {}", block_idx, e))
480        })?;
481
482        Ok(buf)
483    }
484
485    /// Decompress a single block using the provided decompressor.
486    pub fn decompress_block(
487        &mut self,
488        block_idx: usize,
489        decompressor: &impl Decompressor,
490    ) -> Result<Vec<u8>> {
491        let compressed = self.read_block(block_idx)?;
492
493        // Calculate expected decompressed size
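        // All blocks except the last hold exactly `block_size` bytes; e.g. with
        // original_size = 70_000 and 16 KiB blocks there are 5 blocks, and the
        // last one holds 70_000 - 4 * 16_384 = 4_464 bytes.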
494        let is_last_block = block_idx == self.block_index.len() - 1;
495        let expected_size = if is_last_block {
496            let full_blocks = (self.block_index.len() - 1) as u64 * self.header.block_size as u64;
497            (self.header.original_size - full_blocks) as usize
498        } else {
499            self.header.block_size as usize
500        };
501
502        decompressor.decompress_with_size(&compressed, expected_size)
503    }
504
505    /// Decompress all blocks into a contiguous buffer.
506    pub fn decompress_all(&mut self, decompressor: &impl Decompressor) -> Result<Vec<u8>> {
507        let mut output = Vec::with_capacity(self.header.original_size as usize);
508
509        for block_idx in 0..self.block_index.len() {
510            let decompressed = self.decompress_block(block_idx, decompressor)?;
511            output.extend_from_slice(&decompressed);
512        }
513
514        Ok(output)
515    }
516}
517
518/// Writer for compressed tensor files.
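///
/// A minimal in-memory sketch with a custom block size (illustrative; assumes
/// the optional `lz4` feature and a `tensor_bytes` buffer of raw weights):
///
/// ```ignore
/// use std::io::Cursor;
/// use haagenti_lz4::Lz4Compressor;
///
/// let mut buffer = Vec::new();
/// let mut writer = HctWriter::new(
///     Cursor::new(&mut buffer),
///     CompressionAlgorithm::Lz4,
///     DType::F16,
///     vec![1024, 64],
/// )
/// .with_block_size(32 * 1024);
/// writer.compress_data(&tensor_bytes, &Lz4Compressor::new())?;
/// writer.finish()?;
/// ```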
519pub struct HctWriter<W: Write + Seek> {
520    writer: W,
521    algorithm: CompressionAlgorithm,
522    dtype: DType,
523    block_size: u32,
524    shape: Vec<u64>,
525    blocks: Vec<Vec<u8>>,
526    original_size: u64,
527}
528
529impl<W: Write + Seek> HctWriter<W> {
530    /// Create a new HCT writer.
531    pub fn new(writer: W, algorithm: CompressionAlgorithm, dtype: DType, shape: Vec<u64>) -> Self {
532        Self {
533            writer,
534            algorithm,
535            dtype,
536            block_size: DEFAULT_BLOCK_SIZE,
537            shape,
538            blocks: Vec::new(),
539            original_size: 0,
540        }
541    }
542
543    /// Set the block size.
544    pub fn with_block_size(mut self, block_size: u32) -> Self {
545        self.block_size = block_size;
546        self
547    }
548
549    /// Add compressed data for a block.
550    pub fn add_block(&mut self, compressed: Vec<u8>, original_len: usize) {
551        self.blocks.push(compressed);
552        self.original_size += original_len as u64;
553    }
554
555    /// Compress data and add blocks.
556    pub fn compress_data(&mut self, data: &[u8], compressor: &impl Compressor) -> Result<()> {
557        for chunk in data.chunks(self.block_size as usize) {
558            let compressed = compressor.compress(chunk)?;
559            self.add_block(compressed, chunk.len());
560        }
561        Ok(())
562    }
563
564    /// Finalize and write the file.
565    pub fn finish(mut self) -> Result<()> {
566        // Calculate compressed size and build index
567        let mut block_index = Vec::with_capacity(self.blocks.len());
568        let mut offset = 0u32;
569
570        for block in &self.blocks {
571            block_index.push(BlockIndex {
572                offset,
573                compressed_size: block.len() as u32,
574            });
575            offset += block.len() as u32;
576        }
577
578        let compressed_size = offset as u64;
579
580        // Build header
581        let header = HctHeader {
582            algorithm: self.algorithm,
583            dtype: self.dtype,
584            flags: 0,
585            original_size: self.original_size,
586            compressed_size,
587            block_size: self.block_size,
588            num_blocks: self.blocks.len() as u32,
589            shape: self.shape,
590        };
591
592        // Write header
593        self.writer
594            .write_all(&header.to_bytes())
595            .map_err(|e| Error::algorithm("hct", format!("failed to write header: {}", e)))?;
596
597        // Write block index
598        for index in &block_index {
599            self.writer.write_all(&index.to_bytes()).map_err(|e| {
600                Error::algorithm("hct", format!("failed to write block index: {}", e))
601            })?;
602        }
603
604        // Write compressed data
605        for block in &self.blocks {
606            self.writer.write_all(block).map_err(|e| {
607                Error::algorithm("hct", format!("failed to write block data: {}", e))
608            })?;
609        }
610
611        self.writer
612            .flush()
613            .map_err(|e| Error::algorithm("hct", format!("failed to flush: {}", e)))?;
614
615        Ok(())
616    }
617}
618
619/// Compress a tensor file to HCT format.
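///
/// A sketch of intended use (illustrative; assumes the optional `zstd` feature
/// and that the input file holds a raw 4096 × 4096 F16 tensor):
///
/// ```ignore
/// use haagenti_zstd::ZstdCompressor;
///
/// let stats = compress_file(
///     "model.layer0.weight.bin",
///     "model.layer0.weight.hct",
///     &ZstdCompressor::new(),
///     DType::F16,
///     vec![4096, 4096],
/// )?;
/// println!(
///     "{} -> {} bytes ({:.2}x)",
///     stats.original_size, stats.compressed_size, stats.ratio
/// );
/// ```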
620pub fn compress_file(
621    input_path: impl AsRef<Path>,
622    output_path: impl AsRef<Path>,
623    compressor: &impl Compressor,
624    dtype: DType,
625    shape: Vec<u64>,
626) -> Result<CompressionStats> {
627    use std::time::Instant;
628
629    let start = Instant::now();
630
631    // Read input
632    let input_data = std::fs::read(input_path.as_ref())
633        .map_err(|e| Error::algorithm("hct", format!("failed to read input file: {}", e)))?;
634    let original_size = input_data.len();
635
636    // Create output file
637    let output_file = File::create(output_path.as_ref())
638        .map_err(|e| Error::algorithm("hct", format!("failed to create output file: {}", e)))?;
639
640    // Determine algorithm
641    let algorithm = match compressor.algorithm() {
642        haagenti_core::Algorithm::Lz4 => CompressionAlgorithm::Lz4,
643        haagenti_core::Algorithm::Zstd => CompressionAlgorithm::Zstd,
644        _ => return Err(Error::corrupted("unsupported algorithm for HCT")),
645    };
646
647    // Compress
648    let mut writer = HctWriter::new(output_file, algorithm, dtype, shape);
649    writer.compress_data(&input_data, compressor)?;
650    writer.finish()?;
651
652    // Get output size
653    let output_metadata = std::fs::metadata(output_path.as_ref())
654        .map_err(|e| Error::algorithm("hct", format!("failed to get output metadata: {}", e)))?;
655    let compressed_size = output_metadata.len() as usize;
656
657    let elapsed = start.elapsed();
658
659    Ok(CompressionStats {
660        original_size,
661        compressed_size,
662        ratio: original_size as f64 / compressed_size as f64,
663        elapsed_ms: elapsed.as_millis() as u64,
664    })
665}
666
667/// Statistics from compression.
668#[derive(Debug, Clone)]
669pub struct CompressionStats {
670    pub original_size: usize,
671    pub compressed_size: usize,
672    pub ratio: f64,
673    pub elapsed_ms: u64,
674}
675
676// ==================== HCT v2 Writer and Reader ====================
677
678/// Writer for HCT v2 format with checksum and quantization support.
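///
/// A v2 writing sketch with quantization metadata (illustrative; assumes the
/// optional `lz4` feature and a `packed_int4_bytes` buffer of packed weights):
///
/// ```ignore
/// use std::io::Cursor;
/// use haagenti_lz4::Lz4Compressor;
///
/// let quant = QuantizationMetadata {
///     scheme: QuantizationScheme::GptqInt4,
///     group_size: 128,
///     scale_bits: 0x3C00, // 1.0 in f16
///     zero_point: 0,
///     has_per_group_scales: false,
/// };
///
/// let mut buffer = Vec::new();
/// let mut writer = HctWriterV2::new(
///     Cursor::new(&mut buffer),
///     CompressionAlgorithm::Lz4,
///     DType::I4,
///     vec![4096, 4096],
/// )
/// .with_quantization(quant);
/// writer.compress_data(&packed_int4_bytes, &Lz4Compressor::new())?;
/// writer.finish()?;
/// ```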
679pub struct HctWriterV2<W: Write + Seek> {
680    writer: W,
681    algorithm: CompressionAlgorithm,
682    dtype: DType,
683    block_size: u32,
684    shape: Vec<u64>,
685    blocks: Vec<(Vec<u8>, u64)>, // (compressed_data, checksum)
686    original_size: u64,
687    flags: u16,
688    quantization: Option<QuantizationMetadata>,
689}
690
691impl<W: Write + Seek> HctWriterV2<W> {
692    /// Create a new HCT v2 writer with checksums enabled.
693    pub fn new(writer: W, algorithm: CompressionAlgorithm, dtype: DType, shape: Vec<u64>) -> Self {
694        Self {
695            writer,
696            algorithm,
697            dtype,
698            block_size: DEFAULT_BLOCK_SIZE,
699            shape,
700            blocks: Vec::new(),
701            original_size: 0,
702            flags: FLAG_HEADER_CHECKSUM | FLAG_BLOCK_CHECKSUMS,
703            quantization: None,
704        }
705    }
706
707    /// Set the block size.
708    pub fn with_block_size(mut self, block_size: u32) -> Self {
709        self.block_size = block_size;
710        self
711    }
712
713    /// Add quantization metadata.
714    pub fn with_quantization(mut self, quant: QuantizationMetadata) -> Self {
715        self.quantization = Some(quant);
716        self.flags |= FLAG_QUANTIZATION;
717        self
718    }
719
720    /// Disable block checksums (for performance).
721    pub fn without_block_checksums(mut self) -> Self {
722        self.flags &= !FLAG_BLOCK_CHECKSUMS;
723        self
724    }
725
726    /// Add compressed data for a block with checksum.
727    pub fn add_block(&mut self, compressed: Vec<u8>, original_len: usize) {
728        let checksum = if self.flags & FLAG_BLOCK_CHECKSUMS != 0 {
729            xxh3_64(&compressed)
730        } else {
731            0
732        };
733        self.blocks.push((compressed, checksum));
734        self.original_size += original_len as u64;
735    }
736
737    /// Compress data and add blocks.
738    pub fn compress_data(&mut self, data: &[u8], compressor: &impl Compressor) -> Result<()> {
739        for chunk in data.chunks(self.block_size as usize) {
740            let compressed = compressor.compress(chunk)?;
741            self.add_block(compressed, chunk.len());
742        }
743        Ok(())
744    }
745
746    /// Finalize and write the v2 file.
747    pub fn finish(mut self) -> Result<()> {
748        // Calculate compressed size and build v2 index
749        let mut block_index = Vec::with_capacity(self.blocks.len());
750        let mut offset = 0u32;
751
752        for (block, checksum) in &self.blocks {
753            block_index.push(BlockIndexV2 {
754                offset,
755                compressed_size: block.len() as u32,
756                checksum: *checksum,
757            });
758            offset += block.len() as u32;
759        }
760
761        let compressed_size = offset as u64;
762
763        // Build v1-compatible header (with v2 version and flags)
764        let mut header_bytes = [0u8; HctHeader::SIZE];
765
766        // Magic
767        header_bytes[0..4].copy_from_slice(&HCT_MAGIC);
768
769        // Version = 2
770        header_bytes[4..8].copy_from_slice(&HCT_VERSION_V2.to_le_bytes());
771
772        // Algorithm and dtype
773        header_bytes[8] = self.algorithm as u8;
774        header_bytes[9] = self.dtype as u8;
775
776        // Flags (with v2 flags set)
777        header_bytes[10..12].copy_from_slice(&self.flags.to_le_bytes());
778
779        // Sizes
780        header_bytes[12..20].copy_from_slice(&self.original_size.to_le_bytes());
781        header_bytes[20..28].copy_from_slice(&compressed_size.to_le_bytes());
782        header_bytes[28..32].copy_from_slice(&self.block_size.to_le_bytes());
783        header_bytes[32..36].copy_from_slice(&(self.blocks.len() as u32).to_le_bytes());
784
785        // Shape
786        header_bytes[36] = self.shape.len() as u8;
787        for (i, &dim) in self.shape.iter().take(4).enumerate() {
788            let off = 37 + i * 8;
789            header_bytes[off..off + 8].copy_from_slice(&dim.to_le_bytes());
790        }
791
792        // Compute header checksum (over header bytes, excluding the checksum itself)
793        // We'll store checksum in the unused bytes at the end of header
794        let header_checksum = xxh3_64(&header_bytes[..56]); // First 56 bytes
795        header_bytes[56..64].copy_from_slice(&header_checksum.to_le_bytes());
796
797        // Write header
798        self.writer
799            .write_all(&header_bytes)
800            .map_err(|e| Error::algorithm("hct", format!("failed to write header: {}", e)))?;
801
802        // Write quantization metadata if present
803        if let Some(ref quant) = self.quantization {
804            self.writer.write_all(&quant.to_bytes()).map_err(|e| {
805                Error::algorithm("hct", format!("failed to write quantization: {}", e))
806            })?;
807        }
808
809        // Write v2 block index (with checksums)
810        for index in &block_index {
811            self.writer.write_all(&index.to_bytes()).map_err(|e| {
812                Error::algorithm("hct", format!("failed to write block index: {}", e))
813            })?;
814        }
815
816        // Write compressed data
817        for (block, _) in &self.blocks {
818            self.writer.write_all(block).map_err(|e| {
819                Error::algorithm("hct", format!("failed to write block data: {}", e))
820            })?;
821        }
822
823        self.writer
824            .flush()
825            .map_err(|e| Error::algorithm("hct", format!("failed to flush: {}", e)))?;
826
827        Ok(())
828    }
829}
830
831/// Reader for HCT v2 files with checksum validation.
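///
/// A validated-read sketch (illustrative; assumes the optional `lz4` feature
/// and an `hct_bytes` buffer holding a complete v2 file):
///
/// ```ignore
/// use std::io::Cursor;
/// use haagenti_lz4::Lz4Decompressor;
///
/// let mut reader = HctReaderV2::new(Cursor::new(&hct_bytes))?;
/// reader.validate_checksums()?; // integrity check only, no decompression
/// if let Some(q) = reader.quantization() {
///     println!("group size = {}", q.group_size);
/// }
/// let weights = reader.decompress_all_validated(&Lz4Decompressor::new())?;
/// ```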
832pub struct HctReaderV2<R: Read + Seek> {
833    reader: R,
834    header: HctHeader,
835    block_index: Vec<BlockIndexV2>,
836    data_offset: u64,
837    quantization: Option<QuantizationMetadata>,
838}
839
840impl<R: Read + Seek> HctReaderV2<R> {
841    /// Open an HCT v2 file for reading.
842    pub fn new(mut reader: R) -> Result<Self> {
843        // Read header
844        let mut header_buf = [0u8; HctHeader::SIZE];
845        reader
846            .read_exact(&mut header_buf)
847            .map_err(|e| Error::algorithm("hct", format!("failed to read header: {}", e)))?;
848
849        // Parse basic header
850        let header = HctHeader::from_bytes(&header_buf)?;
851
852        // Extract stored header checksum (last 8 bytes of header)
853        let stored_checksum = u64::from_le_bytes(header_buf[56..64].try_into().unwrap());
854
855        // Verify header checksum if v2
856        let version = u32::from_le_bytes(header_buf[4..8].try_into().unwrap());
857        if version >= HCT_VERSION_V2 && header.flags & FLAG_HEADER_CHECKSUM != 0 {
858            let computed = xxh3_64(&header_buf[..56]);
859            if computed != stored_checksum {
860                return Err(Error::corrupted(format!(
861                    "header checksum mismatch: expected {:016x}, got {:016x}",
862                    stored_checksum, computed
863                )));
864            }
865        }
866
867        // Read quantization metadata if present
868        let quantization = if header.flags & FLAG_QUANTIZATION != 0 {
869            let mut quant_buf = [0u8; QuantizationMetadata::SIZE];
870            reader.read_exact(&mut quant_buf).map_err(|e| {
871                Error::algorithm("hct", format!("failed to read quantization: {}", e))
872            })?;
873            Some(QuantizationMetadata::from_bytes(&quant_buf)?)
874        } else {
875            None
876        };
877
        // Determine index entry size based on version. HctWriterV2 always
        // writes 16-byte entries (with checksum = 0 when block checksums are
        // disabled), so every v2 file uses the v2 entry size.
        let index_entry_size = if version >= HCT_VERSION_V2 {
            BlockIndexV2::SIZE
        } else {
            BlockIndex::SIZE
        };
885
886        // Read block index
887        let index_size = header.num_blocks as usize * index_entry_size;
888        let mut index_buf = vec![0u8; index_size];
889        reader
890            .read_exact(&mut index_buf)
891            .map_err(|e| Error::algorithm("hct", format!("failed to read block index: {}", e)))?;
892
893        let block_index: Vec<BlockIndexV2> = if index_entry_size == BlockIndexV2::SIZE {
894            index_buf
895                .chunks_exact(BlockIndexV2::SIZE)
896                .map(|chunk| BlockIndexV2::from_bytes(chunk.try_into().unwrap()))
897                .collect()
898        } else {
899            // Convert v1 index to v2 (no checksums)
900            index_buf
901                .chunks_exact(BlockIndex::SIZE)
902                .map(|chunk| {
903                    let v1 = BlockIndex::from_bytes(chunk.try_into().unwrap());
904                    BlockIndexV2::from_v1(v1)
905                })
906                .collect()
907        };
908
909        let quant_size = if quantization.is_some() {
910            QuantizationMetadata::SIZE
911        } else {
912            0
913        };
914        let data_offset = HctHeader::SIZE as u64 + quant_size as u64 + index_size as u64;
915
916        Ok(Self {
917            reader,
918            header,
919            block_index,
920            data_offset,
921            quantization,
922        })
923    }
924
925    /// Get the header.
926    pub fn header(&self) -> &HctHeader {
927        &self.header
928    }
929
930    /// Get quantization metadata if present.
931    pub fn quantization(&self) -> Option<&QuantizationMetadata> {
932        self.quantization.as_ref()
933    }
934
935    /// Get the number of blocks.
936    pub fn num_blocks(&self) -> usize {
937        self.block_index.len()
938    }
939
940    /// Read and validate a single compressed block.
941    pub fn read_block_validated(&mut self, block_idx: usize) -> Result<Vec<u8>> {
942        if block_idx >= self.block_index.len() {
943            return Err(Error::corrupted(format!(
944                "block index out of range: {} >= {}",
945                block_idx,
946                self.block_index.len()
947            )));
948        }
949
950        let index = &self.block_index[block_idx];
951        let offset = self.data_offset + index.offset as u64;
952
953        self.reader.seek(SeekFrom::Start(offset)).map_err(|e| {
954            Error::algorithm(
955                "hct",
956                format!("failed to seek to block {}: {}", block_idx, e),
957            )
958        })?;
959
960        let mut buf = vec![0u8; index.compressed_size as usize];
961        self.reader.read_exact(&mut buf).map_err(|e| {
962            Error::algorithm("hct", format!("failed to read block {}: {}", block_idx, e))
963        })?;
964
965        // Validate checksum if present
966        if index.checksum != 0 {
967            let computed = xxh3_64(&buf);
968            if computed != index.checksum {
969                return Err(Error::corrupted(format!(
970                    "block {} checksum mismatch: expected {:016x}, got {:016x}",
971                    block_idx, index.checksum, computed
972                )));
973            }
974        }
975
976        Ok(buf)
977    }
978
979    /// Decompress a single block with validation.
980    pub fn decompress_block_validated(
981        &mut self,
982        block_idx: usize,
983        decompressor: &impl Decompressor,
984    ) -> Result<Vec<u8>> {
985        let compressed = self.read_block_validated(block_idx)?;
986
987        // Calculate expected decompressed size
988        let is_last_block = block_idx == self.block_index.len() - 1;
989        let expected_size = if is_last_block {
990            let full_blocks = (self.block_index.len() - 1) as u64 * self.header.block_size as u64;
991            (self.header.original_size - full_blocks) as usize
992        } else {
993            self.header.block_size as usize
994        };
995
996        decompressor.decompress_with_size(&compressed, expected_size)
997    }
998
999    /// Decompress all blocks with validation.
1000    pub fn decompress_all_validated(
1001        &mut self,
1002        decompressor: &impl Decompressor,
1003    ) -> Result<Vec<u8>> {
1004        let mut output = Vec::with_capacity(self.header.original_size as usize);
1005
1006        for block_idx in 0..self.block_index.len() {
1007            let decompressed = self.decompress_block_validated(block_idx, decompressor)?;
1008            output.extend_from_slice(&decompressed);
1009        }
1010
1011        Ok(output)
1012    }
1013
1014    /// Validate all block checksums without decompressing.
1015    pub fn validate_checksums(&mut self) -> Result<()> {
1016        for block_idx in 0..self.block_index.len() {
1017            let _ = self.read_block_validated(block_idx)?;
1018        }
1019        Ok(())
1020    }
1021}
1022
1023/// Checksum validation error.
1024#[derive(Debug, Clone)]
1025pub struct ChecksumError {
1026    /// Expected checksum.
1027    pub expected: u64,
1028    /// Actual computed checksum.
1029    pub actual: u64,
1030    /// Block index (None for header).
1031    pub block_index: Option<usize>,
1032}
1033
1034impl std::fmt::Display for ChecksumError {
1035    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1036        match self.block_index {
1037            Some(idx) => write!(
1038                f,
1039                "block {} checksum mismatch: expected {:016x}, got {:016x}",
1040                idx, self.expected, self.actual
1041            ),
1042            None => write!(
1043                f,
1044                "header checksum mismatch: expected {:016x}, got {:016x}",
1045                self.expected, self.actual
1046            ),
1047        }
1048    }
1049}
1050
1051#[cfg(test)]
1052mod tests {
1053    use super::*;
1054
1055    #[test]
1056    fn test_header_roundtrip() {
1057        let header = HctHeader {
1058            algorithm: CompressionAlgorithm::Zstd,
1059            dtype: DType::I4,
1060            flags: 0,
1061            original_size: 1024 * 1024,
1062            compressed_size: 256 * 1024,
1063            block_size: 64 * 1024,
1064            num_blocks: 16,
1065            shape: vec![4096, 4096],
1066        };
1067
1068        let bytes = header.to_bytes();
1069        let parsed = HctHeader::from_bytes(&bytes).unwrap();
1070
1071        assert_eq!(parsed.algorithm, header.algorithm);
1072        assert_eq!(parsed.dtype, header.dtype);
1073        assert_eq!(parsed.original_size, header.original_size);
1074        assert_eq!(parsed.compressed_size, header.compressed_size);
1075        assert_eq!(parsed.num_blocks, header.num_blocks);
1076        assert_eq!(parsed.shape, header.shape);
1077    }
1078
1079    #[test]
1080    fn test_block_index_roundtrip() {
1081        let index = BlockIndex {
1082            offset: 12345,
1083            compressed_size: 6789,
1084        };
1085
1086        let bytes = index.to_bytes();
1087        let parsed = BlockIndex::from_bytes(&bytes);
1088
1089        assert_eq!(parsed.offset, index.offset);
1090        assert_eq!(parsed.compressed_size, index.compressed_size);
1091    }
1092
1093    #[test]
1094    #[cfg(feature = "zstd")]
1095    fn test_hct_zstd_roundtrip() {
1096        use haagenti_zstd::{ZstdCompressor, ZstdDecompressor};
1097        use std::io::Cursor;
1098
        // Test with 64KB of data using the default 16KB block size (4 blocks)
1100        let original_data: Vec<u8> = (0..65536).map(|i| ((i % 256) as i8) as u8).collect();
1101
1102        // Compress to HCT format using default block size
1103        let mut buffer = Vec::new();
1104        {
1105            let cursor = Cursor::new(&mut buffer);
1106            let compressor = ZstdCompressor::new();
1107
1108            let mut writer = HctWriter::new(
1109                cursor,
1110                CompressionAlgorithm::Zstd,
1111                DType::I8,
1112                vec![256, 256],
1113            );
1114            writer.compress_data(&original_data, &compressor).unwrap();
1115            writer.finish().unwrap();
1116        }
1117
1118        // Verify compression worked
1119        assert!(buffer.len() < original_data.len(), "Should compress");
1120        assert!(&buffer[0..4] == &HCT_MAGIC, "Should start with HCT magic");
1121
1122        // Decompress
1123        let cursor = Cursor::new(&buffer);
1124        let mut reader = HctReader::new(cursor).unwrap();
1125        let decompressor = ZstdDecompressor::new();
1126
1127        // Verify we have 4 blocks (64KB / 16KB default)
1128        assert_eq!(reader.num_blocks(), 4, "Should have 4 blocks");
1129
1130        let decompressed = reader.decompress_all(&decompressor).unwrap();
1131        assert_eq!(decompressed, original_data);
1132    }
1133
1134    #[test]
1135    #[cfg(feature = "lz4")]
1136    fn test_hct_lz4_roundtrip() {
1137        use haagenti_lz4::Lz4Compressor;
1138        use std::io::Cursor;
1139
1140        // Create test data with high compressibility (sparse pattern)
1141        let mut original_data = vec![0u8; 65536];
1142        for i in (0..65536).step_by(100) {
1143            original_data[i] = (i % 256) as u8;
1144        }
1145
1146        // Compress to HCT format
1147        let mut buffer = Vec::new();
1148        let cursor = Cursor::new(&mut buffer);
1149        let compressor = Lz4Compressor::new();
1150
1151        let mut writer =
1152            HctWriter::new(cursor, CompressionAlgorithm::Lz4, DType::I8, vec![256, 256]);
1153        writer.compress_data(&original_data, &compressor).unwrap();
1154        writer.finish().unwrap();
1155
1156        assert!(buffer[0..4] == HCT_MAGIC);
1157
1158        // Decompress
1159        use haagenti_lz4::Lz4Decompressor;
1160
1161        let cursor = Cursor::new(&buffer);
1162        let mut reader = HctReader::new(cursor).unwrap();
1163        let decompressor = Lz4Decompressor::new();
1164
1165        // Copy header values before mutable borrow
1166        let algorithm = reader.header().algorithm;
1167        let dtype = reader.header().dtype;
1168        let original_size = reader.header().original_size;
1169        let block_size = reader.header().block_size;
1170        let num_blocks = reader.num_blocks();
1171
1172        assert_eq!(algorithm, CompressionAlgorithm::Lz4);
1173        assert_eq!(dtype, DType::I8);
1174        assert_eq!(original_size, 65536);
1175
1176        // Decompress individual blocks
1177        for i in 0..num_blocks {
1178            let block = reader.decompress_block(i, &decompressor).unwrap();
1179            let expected_len = if i == num_blocks - 1 {
1180                (original_size as usize) % (block_size as usize)
1181            } else {
1182                block_size as usize
1183            };
1184            // Handle edge case where data size is exact multiple of block size
1185            let expected_len = if expected_len == 0 {
1186                block_size as usize
1187            } else {
1188                expected_len
1189            };
1190            assert_eq!(block.len(), expected_len);
1191        }
1192    }
1193
1194    #[test]
1195    #[cfg(feature = "zstd")]
1196    fn test_hct_block_random_access() {
1197        use haagenti_zstd::{ZstdCompressor, ZstdDecompressor};
1198        use std::io::Cursor;
1199
1200        // Create data with distinct patterns per block
1201        let block_size = 1024u32;
1202        let num_blocks = 4usize;
1203        let mut original_data = Vec::new();
1204        for block_idx in 0..num_blocks {
1205            for _ in 0..block_size {
1206                original_data.push(block_idx as u8 * 10);
1207            }
1208        }
1209
1210        // Compress
1211        let mut buffer = Vec::new();
1212        let cursor = Cursor::new(&mut buffer);
1213        let compressor = ZstdCompressor::new();
1214
1215        let mut writer = HctWriter::new(
1216            cursor,
1217            CompressionAlgorithm::Zstd,
1218            DType::I8,
1219            vec![num_blocks as u64, block_size as u64],
1220        )
1221        .with_block_size(block_size);
1222        writer.compress_data(&original_data, &compressor).unwrap();
1223        writer.finish().unwrap();
1224
1225        // Read blocks out of order (random access)
1226        let cursor = Cursor::new(&buffer);
1227        let mut reader = HctReader::new(cursor).unwrap();
1228        let decompressor = ZstdDecompressor::new();
1229
1230        // Read block 2 first
1231        let block2 = reader.decompress_block(2, &decompressor).unwrap();
1232        assert!(block2.iter().all(|&b| b == 20));
1233
1234        // Read block 0
1235        let block0 = reader.decompress_block(0, &decompressor).unwrap();
1236        assert!(block0.iter().all(|&b| b == 0));
1237
1238        // Read block 3
1239        let block3 = reader.decompress_block(3, &decompressor).unwrap();
1240        assert!(block3.iter().all(|&b| b == 30));
1241
1242        // Read block 1
1243        let block1 = reader.decompress_block(1, &decompressor).unwrap();
1244        assert!(block1.iter().all(|&b| b == 10));
1245    }
1246
1247    // ==================== HCT v2 Tests ====================
1248
1249    #[test]
1250    fn test_quantization_metadata_roundtrip() {
1251        let quant = QuantizationMetadata {
1252            scheme: QuantizationScheme::GptqInt4,
1253            group_size: 128,
1254            scale_bits: 0x3C00, // 1.0 in f16
1255            zero_point: -8,
1256            has_per_group_scales: true,
1257        };
1258
1259        let bytes = quant.to_bytes();
1260        let parsed = QuantizationMetadata::from_bytes(&bytes).unwrap();
1261
1262        assert_eq!(parsed.scheme, QuantizationScheme::GptqInt4);
1263        assert_eq!(parsed.group_size, 128);
1264        assert_eq!(parsed.scale_bits, 0x3C00);
1265        assert_eq!(parsed.zero_point, -8);
1266        assert!(parsed.has_per_group_scales);
1267    }
1268
1269    #[test]
1270    fn test_block_index_v2_roundtrip() {
1271        let index = BlockIndexV2 {
1272            offset: 12345,
1273            compressed_size: 6789,
1274            checksum: 0xDEAD_BEEF_CAFE_BABE,
1275        };
1276
1277        let bytes = index.to_bytes();
1278        let parsed = BlockIndexV2::from_bytes(&bytes);
1279
1280        assert_eq!(parsed.offset, index.offset);
1281        assert_eq!(parsed.compressed_size, index.compressed_size);
1282        assert_eq!(parsed.checksum, index.checksum);
1283    }
1284
1285    #[test]
1286    #[cfg(feature = "lz4")]
1287    fn test_hct_v2_checksum_valid() {
1288        use haagenti_lz4::{Lz4Compressor, Lz4Decompressor};
1289        use std::io::Cursor;
1290
1291        // Create test data
1292        let original_data: Vec<u8> = (0..16384).map(|i| (i % 256) as u8).collect();
1293
1294        // Compress with v2 writer
1295        let mut buffer = Vec::new();
1296        {
1297            let cursor = Cursor::new(&mut buffer);
1298            let compressor = Lz4Compressor::new();
1299
1300            let mut writer =
1301                HctWriterV2::new(cursor, CompressionAlgorithm::Lz4, DType::I8, vec![16384]);
1302            writer.compress_data(&original_data, &compressor).unwrap();
1303            writer.finish().unwrap();
1304        }
1305
1306        // Read with v2 reader and validate checksums
1307        let cursor = Cursor::new(&buffer);
1308        let mut reader = HctReaderV2::new(cursor).unwrap();
1309
1310        // Validate all checksums
1311        reader.validate_checksums().unwrap();
1312
1313        // Decompress with validation
1314        let cursor = Cursor::new(&buffer);
1315        let mut reader = HctReaderV2::new(cursor).unwrap();
1316        let decompressor = Lz4Decompressor::new();
1317        let decompressed = reader.decompress_all_validated(&decompressor).unwrap();
1318
1319        assert_eq!(decompressed, original_data);
1320    }
1321
1322    #[test]
1323    #[cfg(feature = "lz4")]
1324    fn test_hct_v2_checksum_detects_corruption() {
1325        use haagenti_lz4::Lz4Compressor;
1326        use std::io::Cursor;
1327
1328        // Create test data
1329        let original_data: Vec<u8> = (0..16384).map(|i| (i % 256) as u8).collect();
1330
1331        // Compress with v2 writer
1332        let mut buffer = Vec::new();
1333        {
1334            let cursor = Cursor::new(&mut buffer);
1335            let compressor = Lz4Compressor::new();
1336
1337            let mut writer =
1338                HctWriterV2::new(cursor, CompressionAlgorithm::Lz4, DType::I8, vec![16384]);
1339            writer.compress_data(&original_data, &compressor).unwrap();
1340            writer.finish().unwrap();
1341        }
1342
1343        // Corrupt a byte in the compressed data area
1344        // Skip header (64) + index entries (16 * num_blocks)
1345        let corruption_offset = 100; // Somewhere in index/data
1346        buffer[corruption_offset] ^= 0xFF;
1347
1348        // Try to read - should detect corruption
1349        let cursor = Cursor::new(&buffer);
1350        let result = HctReaderV2::new(cursor);
1351
1352        // Either header checksum fails or it parses but block validation fails
1353        match result {
1354            Err(_) => {
1355                // Header checksum failed - expected
1356            }
1357            Ok(mut reader) => {
1358                // Header passed, try to validate blocks
1359                let validate_result = reader.validate_checksums();
1360                assert!(validate_result.is_err(), "Should detect block corruption");
1361            }
1362        }
1363    }
1364
1365    #[test]
1366    #[cfg(feature = "lz4")]
1367    fn test_hct_v2_with_quantization_metadata() {
1368        use haagenti_lz4::{Lz4Compressor, Lz4Decompressor};
1369        use std::io::Cursor;
1370
1371        let original_data: Vec<u8> = (0..4096).map(|i| (i % 256) as u8).collect();
1372
1373        let quant = QuantizationMetadata {
1374            scheme: QuantizationScheme::GptqInt4,
1375            group_size: 128,
1376            scale_bits: 0x3C00,
1377            zero_point: 0,
1378            has_per_group_scales: false,
1379        };
1380
1381        // Compress with quantization metadata
1382        let mut buffer = Vec::new();
1383        {
1384            let cursor = Cursor::new(&mut buffer);
1385            let compressor = Lz4Compressor::new();
1386
1387            let mut writer =
1388                HctWriterV2::new(cursor, CompressionAlgorithm::Lz4, DType::I4, vec![4096])
1389                    .with_quantization(quant);
1390            writer.compress_data(&original_data, &compressor).unwrap();
1391            writer.finish().unwrap();
1392        }
1393
1394        // Read and verify quantization metadata
1395        let cursor = Cursor::new(&buffer);
1396        let mut reader = HctReaderV2::new(cursor).unwrap();
1397
1398        assert!(reader.quantization().is_some());
1399        let read_quant = reader.quantization().unwrap();
1400        assert_eq!(read_quant.scheme, QuantizationScheme::GptqInt4);
1401        assert_eq!(read_quant.group_size, 128);
1402        assert_eq!(read_quant.scale_bits, 0x3C00);
1403
1404        // Verify data integrity
1405        let decompressor = Lz4Decompressor::new();
1406        let decompressed = reader.decompress_all_validated(&decompressor).unwrap();
1407        assert_eq!(decompressed, original_data);
1408    }
1409
1410    #[test]
1411    #[cfg(feature = "lz4")]
1412    fn test_hct_v2_backward_compatible_with_v1_reader() {
1413        use haagenti_lz4::{Lz4Compressor, Lz4Decompressor};
1414        use std::io::Cursor;
1415
1416        // V1 files should still be readable
1417        let original_data: Vec<u8> = (0..8192).map(|i| (i % 256) as u8).collect();
1418
1419        // Write with v1 writer
1420        let mut buffer = Vec::new();
1421        {
1422            let cursor = Cursor::new(&mut buffer);
1423            let compressor = Lz4Compressor::new();
1424
1425            let mut writer =
1426                HctWriter::new(cursor, CompressionAlgorithm::Lz4, DType::I8, vec![8192]);
1427            writer.compress_data(&original_data, &compressor).unwrap();
1428            writer.finish().unwrap();
1429        }
1430
1431        // Read with v1 reader (should work as before)
1432        let cursor = Cursor::new(&buffer);
1433        let mut reader = HctReader::new(cursor).unwrap();
1434        let decompressor = Lz4Decompressor::new();
1435        let decompressed = reader.decompress_all(&decompressor).unwrap();
1436
1437        assert_eq!(decompressed, original_data);
1438    }
1439
1440    // ================================================================================
1441    // Phase 3: Format Edge Case Tests
1442    // ================================================================================
1443
1444    // -------------------- Corrupted Header Tests --------------------
1445
1446    #[test]
1447    fn test_corrupted_magic_number() {
1448        let mut bytes = HctHeader {
1449            algorithm: CompressionAlgorithm::Lz4,
1450            dtype: DType::F32,
1451            flags: 0,
1452            original_size: 1024,
1453            compressed_size: 512,
1454            block_size: 1024,
1455            num_blocks: 1,
1456            shape: vec![32, 32],
1457        }
1458        .to_bytes();
1459
1460        // Corrupt magic number
1461        bytes[0] = 0xFF;
1462        bytes[1] = 0xFF;
1463
1464        let result = HctHeader::from_bytes(&bytes);
1465        assert!(result.is_err(), "Should reject corrupted magic");
1466    }
1467
1468    #[test]
1469    fn test_corrupted_version() {
1470        let mut bytes = HctHeader {
1471            algorithm: CompressionAlgorithm::Lz4,
1472            dtype: DType::F32,
1473            flags: 0,
1474            original_size: 1024,
1475            compressed_size: 512,
1476            block_size: 1024,
1477            num_blocks: 1,
1478            shape: vec![32, 32],
1479        }
1480        .to_bytes();
1481
1482        // Corrupt version (byte 4)
1483        bytes[4] = 0xFF;
1484
1485        let result = HctHeader::from_bytes(&bytes);
1486        assert!(result.is_err(), "Should reject unsupported version");
1487    }
1488
1489    #[test]
1490    fn test_corrupted_algorithm_field() {
1491        let mut bytes = HctHeader {
1492            algorithm: CompressionAlgorithm::Lz4,
1493            dtype: DType::F32,
1494            flags: 0,
1495            original_size: 1024,
1496            compressed_size: 512,
1497            block_size: 1024,
1498            num_blocks: 1,
1499            shape: vec![32, 32],
1500        }
1501        .to_bytes();
1502
1503        // Set algorithm to invalid value (byte 5)
1504        bytes[5] = 0xFF;
1505
1506        let result = HctHeader::from_bytes(&bytes);
1507        assert!(result.is_err(), "Should reject invalid algorithm");
1508    }
1509
1510    #[test]
1511    fn test_corrupted_dtype_field() {
1512        let mut bytes = HctHeader {
1513            algorithm: CompressionAlgorithm::Lz4,
1514            dtype: DType::F32,
1515            flags: 0,
1516            original_size: 1024,
1517            compressed_size: 512,
1518            block_size: 1024,
1519            num_blocks: 1,
1520            shape: vec![32, 32],
1521        }
1522        .to_bytes();
1523
1524        // Set dtype to invalid value (byte 6)
1525        bytes[6] = 0xFF;
1526
1527        let result = HctHeader::from_bytes(&bytes);
1528        assert!(result.is_err(), "Should reject invalid dtype");
1529    }
1530
1531    // -------------------- Truncated Data Tests --------------------
1532
1533    #[test]
1534    fn test_truncated_header() {
        // HctHeader::from_bytes takes a fixed-size &[u8; HctHeader::SIZE], so a
        // truncated buffer cannot be passed to it at all; this test only pins
        // the size constant that callers use to pre-validate input length.
        assert!(
            HctHeader::SIZE >= 32,
            "Header should require at least 32 bytes"
        );
1545    }
1546
1547    // -------------------- All Quantization Schemes --------------------
1548
1549    #[test]
1550    fn test_quantization_scheme_symmetric_int8() {
1551        let quant = QuantizationMetadata {
1552            scheme: QuantizationScheme::SymmetricInt8,
1553            group_size: 64,
1554            scale_bits: 0x4000, // 2.0 in f16
1555            zero_point: 0,
1556            has_per_group_scales: false,
1557        };
1558
1559        let bytes = quant.to_bytes();
1560        let parsed = QuantizationMetadata::from_bytes(&bytes).unwrap();
1561
1562        assert_eq!(parsed.scheme, QuantizationScheme::SymmetricInt8);
1563        assert_eq!(parsed.group_size, 64);
1564        assert_eq!(parsed.zero_point, 0);
1565    }
1566
1567    #[test]
1568    fn test_quantization_scheme_asymmetric_int8() {
1569        let quant = QuantizationMetadata {
1570            scheme: QuantizationScheme::AsymmetricInt8,
1571            group_size: 32,
1572            scale_bits: 0x3C00, // 1.0 in f16
1573            zero_point: -128,   // Minimum value for i8
1574            has_per_group_scales: true,
1575        };
1576
1577        let bytes = quant.to_bytes();
1578        let parsed = QuantizationMetadata::from_bytes(&bytes).unwrap();
1579
1580        assert_eq!(parsed.scheme, QuantizationScheme::AsymmetricInt8);
1581        assert_eq!(parsed.zero_point, -128);
1582        assert!(parsed.has_per_group_scales);
1583    }
1584
1585    #[test]
1586    fn test_quantization_scheme_awq_int4() {
1587        let quant = QuantizationMetadata {
1588            scheme: QuantizationScheme::AwqInt4,
1589            group_size: 128,
1590            scale_bits: 0x3800, // 0.5 in f16
1591            zero_point: 8,
1592            has_per_group_scales: true,
1593        };
1594
1595        let bytes = quant.to_bytes();
1596        let parsed = QuantizationMetadata::from_bytes(&bytes).unwrap();
1597
1598        assert_eq!(parsed.scheme, QuantizationScheme::AwqInt4);
1599        assert_eq!(parsed.group_size, 128);
1600    }
1601
1602    #[test]
1603    fn test_quantization_scheme_gptq_int4() {
1604        let quant = QuantizationMetadata {
1605            scheme: QuantizationScheme::GptqInt4,
1606            group_size: 128,
1607            scale_bits: 0x3C00, // 1.0 in f16
1608            zero_point: 0,
1609            has_per_group_scales: true,
1610        };
1611
1612        let bytes = quant.to_bytes();
1613        let parsed = QuantizationMetadata::from_bytes(&bytes).unwrap();
1614
1615        assert_eq!(parsed.scheme, QuantizationScheme::GptqInt4);
1616    }
1617
1618    #[test]
1619    fn test_quantization_scheme_none() {
1620        let quant = QuantizationMetadata {
1621            scheme: QuantizationScheme::None,
1622            group_size: 0,
1623            scale_bits: 0,
1624            zero_point: 0,
1625            has_per_group_scales: false,
1626        };
1627
1628        let bytes = quant.to_bytes();
1629        let parsed = QuantizationMetadata::from_bytes(&bytes).unwrap();
1630
1631        assert_eq!(parsed.scheme, QuantizationScheme::None);
1632    }
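
    // Illustrative sketch, not part of the HCT implementation: decodes the f16 bit
    // patterns used for `scale_bits` in the quantization tests above
    // (0x3C00 = 1.0, 0x4000 = 2.0, 0x3800 = 0.5). The local helper handles normal
    // finite values only, which is all these tests use.
    #[test]
    fn test_scale_bits_f16_examples() {
        fn f16_to_f32(bits: u16) -> f32 {
            let sign = if bits & 0x8000 != 0 { -1.0 } else { 1.0 };
            let exp = ((bits >> 10) & 0x1F) as i32;
            let frac = (bits & 0x03FF) as f32;
            // Normal f16 values decode as (1 + frac/1024) * 2^(exp - 15)
            sign * (1.0 + frac / 1024.0) * 2f32.powi(exp - 15)
        }

        assert_eq!(f16_to_f32(0x3C00), 1.0);
        assert_eq!(f16_to_f32(0x4000), 2.0);
        assert_eq!(f16_to_f32(0x3800), 0.5);
    }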
1633
1634    // -------------------- Block Boundary Edge Cases --------------------
1635
1636    #[test]
1637    fn test_header_zero_blocks() {
1638        let header = HctHeader {
1639            algorithm: CompressionAlgorithm::Lz4,
1640            dtype: DType::F32,
1641            flags: 0,
1642            original_size: 0,
1643            compressed_size: 0,
1644            block_size: 1024,
1645            num_blocks: 0,
1646            shape: vec![0],
1647        };
1648
1649        let bytes = header.to_bytes();
1650        let parsed = HctHeader::from_bytes(&bytes).unwrap();
1651
1652        assert_eq!(parsed.num_blocks, 0);
1653        assert_eq!(parsed.original_size, 0);
1654    }
1655
1656    #[test]
1657    fn test_header_single_block() {
1658        let header = HctHeader {
1659            algorithm: CompressionAlgorithm::Zstd,
1660            dtype: DType::F32,
1661            flags: 0,
1662            original_size: 512,
1663            compressed_size: 256,
1664            block_size: 1024,
1665            num_blocks: 1,
1666            shape: vec![128],
1667        };
1668
1669        let bytes = header.to_bytes();
1670        let parsed = HctHeader::from_bytes(&bytes).unwrap();
1671
1672        assert_eq!(parsed.num_blocks, 1);
1673        // Data is smaller than block size
1674        assert!(parsed.original_size < parsed.block_size as u64);
1675    }
1676
1677    #[test]
1678    fn test_header_exact_block_multiple() {
1679        let header = HctHeader {
1680            algorithm: CompressionAlgorithm::Lz4,
1681            dtype: DType::I8,
1682            flags: 0,
1683            original_size: 4096,
1684            compressed_size: 2048,
1685            block_size: 1024,
1686            num_blocks: 4,
1687            shape: vec![4096],
1688        };
1689
1690        let bytes = header.to_bytes();
1691        let parsed = HctHeader::from_bytes(&bytes).unwrap();
1692
1693        // 4096 / 1024 = 4 blocks exactly
1694        assert_eq!(parsed.num_blocks, 4);
1695        assert_eq!(parsed.original_size, 4 * parsed.block_size as u64);
1696    }
1697
1698    #[test]
1699    fn test_header_partial_final_block() {
1700        let header = HctHeader {
1701            algorithm: CompressionAlgorithm::Lz4,
1702            dtype: DType::I8,
1703            flags: 0,
1704            original_size: 4500,
1705            compressed_size: 2250,
1706            block_size: 1024,
1707            num_blocks: 5,
1708            shape: vec![4500],
1709        };
1710
1711        let bytes = header.to_bytes();
1712        let parsed = HctHeader::from_bytes(&bytes).unwrap();
1713
1714        // 4500 / 1024 = 4.39... -> 5 blocks
1715        // Last block has 4500 - 4*1024 = 404 bytes
1716        assert_eq!(parsed.num_blocks, 5);
1717        let last_block_size = parsed.original_size as u32 % parsed.block_size;
1718        assert_eq!(last_block_size, 404);
1719    }
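
    // Illustrative only: the ceil-division block count assumed by the header tests
    // above (e.g. 4500 bytes at a 1024-byte block size -> 5 blocks, with a 404-byte
    // final block). This checks the arithmetic directly; it does not call the writer.
    #[test]
    fn test_block_count_arithmetic_examples() {
        fn expected_blocks(original_size: u64, block_size: u64) -> u64 {
            // ceil(original_size / block_size) without floating point
            (original_size + block_size - 1) / block_size
        }

        assert_eq!(expected_blocks(4096, 1024), 4); // exact multiple of the block size
        assert_eq!(expected_blocks(4500, 1024), 5); // partial final block
        assert_eq!(4500u64 % 1024, 404); // bytes in the partial final block
        assert_eq!(expected_blocks(512, 1024), 1); // smaller than one block
        assert_eq!(expected_blocks(0, 1024), 0); // empty tensor -> zero blocks
    }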
1720
1721    // -------------------- Shape Dimension Tests --------------------
1722
1723    #[test]
1724    fn test_header_1d_shape() {
1725        let header = HctHeader {
1726            algorithm: CompressionAlgorithm::Lz4,
1727            dtype: DType::F32,
1728            flags: 0,
1729            original_size: 4096,
1730            compressed_size: 2048,
1731            block_size: 1024,
1732            num_blocks: 4,
1733            shape: vec![1024],
1734        };
1735
1736        let bytes = header.to_bytes();
1737        let parsed = HctHeader::from_bytes(&bytes).unwrap();
1738
1739        assert_eq!(parsed.shape.len(), 1);
1740        assert_eq!(parsed.shape[0], 1024);
1741    }
1742
1743    #[test]
1744    fn test_header_2d_shape() {
1745        let header = HctHeader {
1746            algorithm: CompressionAlgorithm::Lz4,
1747            dtype: DType::F32,
1748            flags: 0,
1749            original_size: 4096,
1750            compressed_size: 2048,
1751            block_size: 1024,
1752            num_blocks: 4,
1753            shape: vec![32, 32],
1754        };
1755
1756        let bytes = header.to_bytes();
1757        let parsed = HctHeader::from_bytes(&bytes).unwrap();
1758
1759        assert_eq!(parsed.shape.len(), 2);
1760        assert_eq!(parsed.shape, vec![32, 32]);
1761    }
1762
1763    #[test]
1764    fn test_header_3d_shape() {
1765        let header = HctHeader {
1766            algorithm: CompressionAlgorithm::Lz4,
1767            dtype: DType::F32,
1768            flags: 0,
1769            original_size: 4096,
1770            compressed_size: 2048,
1771            block_size: 1024,
1772            num_blocks: 4,
1773            shape: vec![4, 16, 64],
1774        };
1775
1776        let bytes = header.to_bytes();
1777        let parsed = HctHeader::from_bytes(&bytes).unwrap();
1778
1779        assert_eq!(parsed.shape.len(), 3);
1780        assert_eq!(parsed.shape, vec![4, 16, 64]);
1781    }
1782
1783    #[test]
1784    fn test_header_max_3_dimensions() {
1785        // The header format has 64 bytes total:
1786        // - 37 bytes fixed fields
1787        // - 27 bytes remaining for shape (3 dimensions * 8 bytes = 24, plus padding)
1788        // A 4D shape would need 37 + 4*8 = 69 bytes, which exceeds the header size.
1789        // The implementation caps the rank at 4, but only 3 dimensions fit cleanly.
1790
1791        // Verify header size constraint
1792        assert_eq!(HctHeader::SIZE, 64);
1793
1794        // Shape storage: rank at byte 36, dimensions starting at byte 37
1795        // For 3 dimensions: 37 + 3*8 = 61 bytes (fits)
1796        // For 4 dimensions: 37 + 4*8 = 69 bytes (overflow!)
1797
1798        // Test that 3D with large values works
1799        let header = HctHeader {
1800            algorithm: CompressionAlgorithm::Zstd,
1801            dtype: DType::BF16,
1802            flags: 0,
1803            original_size: u64::MAX / 2,
1804            compressed_size: u64::MAX / 4,
1805            block_size: u32::MAX,
1806            num_blocks: u32::MAX / 2,
1807            shape: vec![8192, 8192, 128], // Large 3D tensor
1808        };
1809
1810        let bytes = header.to_bytes();
1811        let parsed = HctHeader::from_bytes(&bytes).unwrap();
1812
1813        assert_eq!(parsed.shape.len(), 3);
1814        assert_eq!(parsed.shape[0], 8192);
1815        assert_eq!(parsed.shape[1], 8192);
1816        assert_eq!(parsed.shape[2], 128);
1817    }
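
    // Sanity check of the offset arithmetic described in the previous test, assuming
    // (as the comments above state) that the shape dims start at byte 37: three u64
    // dims fit within the 64-byte header and a fourth would not.
    #[test]
    fn test_shape_dims_offset_arithmetic() {
        const DIMS_OFFSET: usize = 37; // mirrors the layout comment above
        assert!(DIMS_OFFSET + 3 * 8 <= HctHeader::SIZE, "3 dims must fit");
        assert!(DIMS_OFFSET + 4 * 8 > HctHeader::SIZE, "4 dims must not fit");
    }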
1818
1819    // -------------------- DType Tests --------------------
1820
1821    #[test]
1822    fn test_all_dtypes_roundtrip() {
1823        let dtypes = [DType::F32, DType::F16, DType::BF16, DType::I8, DType::I4];
1824
1825        for dtype in dtypes {
1826            let header = HctHeader {
1827                algorithm: CompressionAlgorithm::Lz4,
1828                dtype,
1829                flags: 0,
1830                original_size: 1024,
1831                compressed_size: 512,
1832                block_size: 1024,
1833                num_blocks: 1,
1834                shape: vec![256],
1835            };
1836
1837            let bytes = header.to_bytes();
1838            let parsed = HctHeader::from_bytes(&bytes).unwrap();
1839
1840            assert_eq!(parsed.dtype, dtype, "DType {:?} should roundtrip", dtype);
1841        }
1842    }
1843
1844    // -------------------- Algorithm Tests --------------------
1845
1846    #[test]
1847    fn test_all_algorithms_roundtrip() {
1848        let algorithms = [CompressionAlgorithm::Lz4, CompressionAlgorithm::Zstd];
1849
1850        for algorithm in algorithms {
1851            let header = HctHeader {
1852                algorithm,
1853                dtype: DType::F32,
1854                flags: 0,
1855                original_size: 1024,
1856                compressed_size: 512,
1857                block_size: 1024,
1858                num_blocks: 1,
1859                shape: vec![256],
1860            };
1861
1862            let bytes = header.to_bytes();
1863            let parsed = HctHeader::from_bytes(&bytes).unwrap();
1864
1865            assert_eq!(
1866                parsed.algorithm, algorithm,
1867                "Algorithm {:?} should roundtrip",
1868                algorithm
1869            );
1870        }
1871    }
1872
1873    // -------------------- Flags Tests --------------------
1874
1875    #[test]
1876    fn test_flags_preserved() {
1877        let flags_to_test = [0x0000, 0x0001, 0x0002, 0xFFFF];
1878
1879        for flags in flags_to_test {
1880            let header = HctHeader {
1881                algorithm: CompressionAlgorithm::Lz4,
1882                dtype: DType::F32,
1883                flags,
1884                original_size: 1024,
1885                compressed_size: 512,
1886                block_size: 1024,
1887                num_blocks: 1,
1888                shape: vec![256],
1889            };
1890
1891            let bytes = header.to_bytes();
1892            let parsed = HctHeader::from_bytes(&bytes).unwrap();
1893
1894            assert_eq!(
1895                parsed.flags, flags,
1896                "Flags {:04X} should be preserved",
1897                flags
1898            );
1899        }
1900    }
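
    // The v2 FLAG_* constants defined at the top of this file are intended to be
    // independent bits; this checks that they occupy distinct bits and can be OR-ed
    // into a single flags value like the ones exercised above.
    #[test]
    fn test_flag_bits_are_distinct() {
        let flags = [
            FLAG_HEADER_CHECKSUM,
            FLAG_BLOCK_CHECKSUMS,
            FLAG_QUANTIZATION,
            FLAG_TENSOR_NAME,
            FLAG_HOLOGRAPHIC,
        ];
        let mut combined = 0u16;
        for flag in flags {
            assert_eq!(combined & flag, 0, "flag bits must not overlap");
            combined |= flag;
        }
    }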
1901
1902    // -------------------- Large Value Tests --------------------
1903
1904    #[test]
1905    fn test_large_original_size() {
1906        let header = HctHeader {
1907            algorithm: CompressionAlgorithm::Lz4,
1908            dtype: DType::F32,
1909            flags: 0,
1910            original_size: u64::MAX - 1,
1911            compressed_size: u64::MAX / 2,
1912            block_size: u32::MAX,
1913            num_blocks: u32::MAX,
1914            shape: vec![u64::MAX],
1915        };
1916
1917        let bytes = header.to_bytes();
1918        let parsed = HctHeader::from_bytes(&bytes).unwrap();
1919
1920        assert_eq!(parsed.original_size, u64::MAX - 1);
1921        assert_eq!(parsed.compressed_size, u64::MAX / 2);
1922        assert_eq!(parsed.block_size, u32::MAX);
1923        assert_eq!(parsed.num_blocks, u32::MAX);
1924    }
1925
1926    // -------------------- Block Index Tests --------------------
1927
1928    #[test]
1929    fn test_block_index_large_values() {
1930        let index = BlockIndex {
1931            offset: u32::MAX - 1,
1932            compressed_size: u32::MAX - 1,
1933        };
1934
1935        let bytes = index.to_bytes();
1936        let parsed = BlockIndex::from_bytes(&bytes);
1937
1938        assert_eq!(parsed.offset, u32::MAX - 1);
1939        assert_eq!(parsed.compressed_size, u32::MAX - 1);
1940    }
1941
1942    #[test]
1943    fn test_block_index_v2_checksum_uniqueness() {
1944        let index1 = BlockIndexV2 {
1945            offset: 100,
1946            compressed_size: 50,
1947            checksum: 0xABCD_EF01_2345_6789,
1948        };
1949
1950        let index2 = BlockIndexV2 {
1951            offset: 100,
1952            compressed_size: 50,
1953            checksum: 0x9876_5432_10FE_DCBA,
1954        };
1955
1956        let bytes1 = index1.to_bytes();
1957        let bytes2 = index2.to_bytes();
1958
1959        // Same offset/size but different checksum should produce different bytes
1960        assert_ne!(bytes1, bytes2);
1961
1962        // And roundtrip correctly
1963        let parsed1 = BlockIndexV2::from_bytes(&bytes1);
1964        let parsed2 = BlockIndexV2::from_bytes(&bytes2);
1965
1966        assert_eq!(parsed1.checksum, index1.checksum);
1967        assert_eq!(parsed2.checksum, index2.checksum);
1968    }
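
    // Illustrates the primitive behind the per-block checksums (FLAG_BLOCK_CHECKSUMS
    // is documented as XXH3-64): a single flipped bit changes the digest, which is
    // what allows corruption to be detected. This exercises the hash directly rather
    // than the reader's validation path.
    #[test]
    fn test_xxh3_detects_single_bitflip() {
        let mut block = vec![0x5Au8; 1024];
        let clean = xxhash_rust::xxh3::xxh3_64(&block);

        block[512] ^= 0x01; // flip one bit in the middle of the block
        let corrupted = xxhash_rust::xxh3::xxh3_64(&block);

        assert_ne!(clean, corrupted, "checksum should change after a bit flip");
    }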
1969
1970    // -------------------- Error Condition Tests --------------------
1971
1972    #[test]
1973    fn test_reader_invalid_block_index_bounds() {
1974        // Create a valid HCT file in memory
1975        let data = vec![0u8; 4096]; // Some data to compress
1976        let mut output = Vec::new();
1977
1978        {
1979            let cursor = std::io::Cursor::new(&mut output);
1980            let mut writer =
1981                HctWriter::new(cursor, CompressionAlgorithm::Lz4, DType::F32, vec![64, 16])
1982                    .with_block_size(1024);
1983
1984            let codec = haagenti_lz4::Lz4Codec::new();
1985            writer.compress_data(&data, &codec).unwrap();
1986            writer.finish().unwrap();
1987        }
1988
1989        // Read it back and try to access invalid block
1990        let cursor = std::io::Cursor::new(&output);
1991        let mut reader = HctReader::new(cursor).unwrap();
1992
1993        // Try to read a block that doesn't exist
1994        let result = reader.read_block(999);
1995        assert!(result.is_err(), "Should error on invalid block index");
1996        let err_msg = format!("{}", result.unwrap_err());
1997        assert!(
1998            err_msg.contains("block index out of range") || err_msg.contains("corrupted"),
1999            "Error should mention invalid block: {}",
2000            err_msg
2001        );
2002    }
2003
2004    #[test]
2005    fn test_reader_truncated_header() {
2006        // Create a header that's too short
2007        let short_data = vec![0u8; HctHeader::SIZE - 10];
2008        let cursor = std::io::Cursor::new(short_data);
2009
2010        let result = HctReader::new(cursor);
2011        assert!(result.is_err(), "Should error on truncated header");
2012    }
2013
2014    #[test]
2015    fn test_reader_truncated_block_index() {
2016        // Create a header with 10 blocks, but truncate the block index section
2017        let mut data = [0u8; HctHeader::SIZE];
2018
2019        // Write valid magic
2020        data[0..4].copy_from_slice(&HCT_MAGIC);
2021        // Write valid version
2022        data[4..8].copy_from_slice(&HCT_VERSION.to_le_bytes());
2023        // Algorithm (LZ4 = 0)
2024        data[8] = 0;
2025        // DType (F32 = 0)
2026        data[9] = 0;
2027        // num_blocks = 10
2028        data[32..36].copy_from_slice(&10u32.to_le_bytes());
2029        // rank = 1
2030        data[36] = 1;
2031
2032        // Only provide header, no block index data
2033        let cursor = std::io::Cursor::new(data.to_vec());
2034
2035        let result = HctReader::new(cursor);
2036        assert!(
2037            result.is_err(),
2038            "Should error when block index is truncated"
2039        );
2040    }
2041
2042    #[test]
2043    fn test_v2_reader_handles_corruption_without_panic() {
2044        // Create a valid v2 HCT file
2045        let data = vec![42u8; 2048]; // Some data
2046        let mut output = Vec::new();
2047
2048        {
2049            let cursor = std::io::Cursor::new(&mut output);
2050            let mut writer =
2051                HctWriterV2::new(cursor, CompressionAlgorithm::Zstd, DType::F32, vec![32, 16])
2052                    .with_block_size(1024);
2053
2054            let codec = haagenti_zstd::ZstdCodec::new();
2055            writer.compress_data(&data, &codec).unwrap();
2056            writer.finish().unwrap();
2057        }
2058
2059        // Corrupt a byte in the compressed data section (after header and index)
2060        let header_size = HctHeader::SIZE;
2061        // Find where block data starts (after header + block index entries)
2062        // For v2: header + (num_blocks * 16) for block index
2063        if output.len() > header_size + 32 {
2064            let corrupt_pos = header_size + 50; // Somewhere in block index or data
2065            if corrupt_pos < output.len() {
2066                output[corrupt_pos] ^= 0xFF; // Flip all bits
2067            }
2068        }
2069
2070        // Try to read - should detect corruption in v2 reader
2071        let cursor = std::io::Cursor::new(&output);
2072        let reader_result = HctReaderV2::new(cursor);
2073
2074        // The corruption might be detected during:
2075        // 1. Block index parsing (if we corrupted index)
2076        // 2. Block read with checksum validation (if we corrupted data)
2077        // Either way, corruption should eventually be detected
2078        if let Ok(mut reader) = reader_result {
2079            // Try to read the block - checksum validation should fail
2080            let block_result = reader.read_block_validated(0);
2081            // May or may not error depending on what we corrupted
2082            // The important thing is the code handles it without panicking
2083            let _ = block_result;
2084        }
2085    }
2086
2087    #[test]
2088    fn test_empty_data_compression() {
2089        // Compress empty data
2090        let data: Vec<u8> = vec![];
2091        let mut output = Vec::new();
2092
2093        {
2094            let cursor = std::io::Cursor::new(&mut output);
2095            let mut writer = HctWriter::new(
2096                cursor,
2097                CompressionAlgorithm::Lz4,
2098                DType::F32,
2099                vec![0], // Shape with a single zero-length dimension
2100            )
2101            .with_block_size(1024);
2102
2103            let codec = haagenti_lz4::Lz4Codec::new();
2104            writer.compress_data(&data, &codec).unwrap();
2105            writer.finish().unwrap();
2106        }
2107
2108        // Read it back
2109        let cursor = std::io::Cursor::new(&output);
2110        let mut reader = HctReader::new(cursor).unwrap();
2111
2112        assert_eq!(reader.header().num_blocks, 0);
2113        assert_eq!(reader.header().original_size, 0);
2114    }
2115
2116    #[test]
2117    fn test_reader_with_completely_invalid_data() {
2118        // Random garbage that's not a valid HCT file
2119        let garbage = vec![0xDE, 0xAD, 0xBE, 0xEF, 0x12, 0x34, 0x56, 0x78];
2120        let cursor = std::io::Cursor::new(garbage);
2121
2122        let result = HctReader::new(cursor);
2123        assert!(result.is_err(), "Should reject invalid data");
2124    }
2125
2126    #[test]
2127    fn test_writer_finish_roundtrip() {
2128        // A single compress_data followed by finish() should produce a readable file
2129        let mut output = Vec::new();
2130        let data = vec![1u8; 100];
2131
2132        let cursor = std::io::Cursor::new(&mut output);
2133        let mut writer = HctWriter::new(cursor, CompressionAlgorithm::Lz4, DType::F32, vec![100])
2134            .with_block_size(64);
2135
2136        let codec = haagenti_lz4::Lz4Codec::new();
2137        writer.compress_data(&data, &codec).unwrap();
2138        // finish() should succeed
2139        writer.finish().unwrap();
2140
2141        // Output should have valid structure
2142        let cursor = std::io::Cursor::new(&output);
2143        let reader = HctReader::new(cursor);
2144        assert!(
2145            reader.is_ok(),
2146            "Should be able to read back compressed data"
2147        );
2148    }
2149
2150    #[test]
2151    fn test_block_boundary_at_exact_size() {
2152        // Data size exactly matches block size
2153        let block_size = 256;
2154        let data = vec![0xABu8; block_size];
2155        let mut output = Vec::new();
2156
2157        {
2158            let cursor = std::io::Cursor::new(&mut output);
2159            let mut writer = HctWriter::new(
2160                cursor,
2161                CompressionAlgorithm::Lz4,
2162                DType::F32,
2163                vec![block_size as u64],
2164            )
2165            .with_block_size(block_size as u32);
2166
2167            let codec = haagenti_lz4::Lz4Codec::new();
2168            writer.compress_data(&data, &codec).unwrap();
2169            writer.finish().unwrap();
2170        }
2171
2172        let cursor = std::io::Cursor::new(&output);
2173        let reader = HctReader::new(cursor).unwrap();
2174
2175        // Should have exactly 1 block
2176        assert_eq!(reader.header().num_blocks, 1);
2177    }
2178
2179    #[test]
2180    fn test_block_boundary_at_size_plus_one() {
2181        // Data size is exactly block size + 1 (forces 2 blocks)
2182        let block_size = 256;
2183        let data = vec![0xCDu8; block_size + 1];
2184        let mut output = Vec::new();
2185
2186        {
2187            let cursor = std::io::Cursor::new(&mut output);
2188            let mut writer = HctWriter::new(
2189                cursor,
2190                CompressionAlgorithm::Lz4,
2191                DType::F32,
2192                vec![(block_size + 1) as u64],
2193            )
2194            .with_block_size(block_size as u32);
2195
2196            let codec = haagenti_lz4::Lz4Codec::new();
2197            writer.compress_data(&data, &codec).unwrap();
2198            writer.finish().unwrap();
2199        }
2200
2201        let cursor = std::io::Cursor::new(&output);
2202        let reader = HctReader::new(cursor).unwrap();
2203
2204        // Should have exactly 2 blocks
2205        assert_eq!(reader.header().num_blocks, 2);
2206    }
2207
2208    #[test]
2209    fn test_header_records_compression_algorithm() {
2210        // Compress with LZ4
2211        let data = vec![0x12u8; 512];
2212        let mut output = Vec::new();
2213
2214        {
2215            let cursor = std::io::Cursor::new(&mut output);
2216            let mut writer =
2217                HctWriter::new(cursor, CompressionAlgorithm::Lz4, DType::F32, vec![512])
2218                    .with_block_size(256);
2219
2220            let codec = haagenti_lz4::Lz4Codec::new();
2221            writer.compress_data(&data, &codec).unwrap();
2222            writer.finish().unwrap();
2223        }
2224
2225        // Read back
2226        let cursor = std::io::Cursor::new(&output);
2227        let mut reader = HctReader::new(cursor).unwrap();
2228
2229        // The header correctly reports LZ4
2230        assert_eq!(reader.header().algorithm, CompressionAlgorithm::Lz4);
2231
2232        // Decompress with the correct algorithm should work
2233        let lz4 = haagenti_lz4::Lz4Codec::new();
2234        let result = reader.decompress_all(&lz4);
2235        assert!(
2236            result.is_ok(),
2237            "Decompression with correct algorithm should work"
2238        );
2239        assert_eq!(result.unwrap(), data);
2240    }
2241}