// haagenti_python/lib.rs
1// PyO3 deprecation warnings - these require PyO3 version updates to fix properly
2#![allow(deprecated)]
3// Allow manual div_ceil in numeric code
4#![allow(clippy::manual_div_ceil)]
5// Unused import from numpy feature flags
6#![allow(unused_imports)]
7// PyO3 error handling macro false positives
8#![allow(clippy::useless_conversion)]
9// PyO3 cfg condition for gil-refs feature
10#![allow(unexpected_cfgs)]
11
12//! Python bindings for Haagenti tensor compression library.
13//!
14//! Provides:
15//! - HCT format reading/writing for tensor storage
16//! - HoloTensor progressive encoding/decoding
17//! - LZ4/Zstd compression backends
18//!
19//! # Example (Python)
20//! ```python
21//! from haagenti import HctReader, CompressionAlgorithm, DType
22//!
23//! # Read an HCT file
24//! reader = HctReader("model.hct")
25//! header = reader.header()
26//! print(f"Shape: {header.shape}, DType: {header.dtype}")
27//!
28//! # Decompress all data
29//! data = reader.decompress_all()
30//! ```
31
32use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1, ToPyArray};
33use pyo3::exceptions::{PyIOError, PyValueError};
34use pyo3::prelude::*;
35use std::fs::File;
36use std::io::{BufReader, BufWriter};
37
38// Re-exports from haagenti
39use haagenti::{
40    CompressionAlgorithm as RustCompressionAlgorithm, Compressor, DType as RustDType, Decompressor,
41    HctHeader as RustHctHeader, HctReaderV2, HctWriterV2,
42    HolographicEncoding as RustHolographicEncoding, QuantizationScheme as RustQuantizationScheme,
43};
44
// Aliases giving the V2 reader/writer the same `Rust`-prefixed naming as the
// other re-exports above. NOTE(review): these are pure pass-through aliases
// (identical to HctReaderV2/HctWriterV2); kept for naming consistency only.
type RustHctReaderV2<R> = HctReaderV2<R>;
type RustHctWriterV2<W> = HctWriterV2<W>;
48
49use haagenti_lz4::{Lz4Compressor, Lz4Decompressor};
50use haagenti_zstd::{ZstdCompressor, ZstdDecompressor};
51
52// ============================================================================
53// Enums
54// ============================================================================
55
56/// Compression algorithm for HCT files.
57#[pyclass]
58#[derive(Clone, Copy, Debug, PartialEq)]
59pub enum CompressionAlgorithm {
60    Lz4,
61    Zstd,
62}
63
64impl From<RustCompressionAlgorithm> for CompressionAlgorithm {
65    fn from(algo: RustCompressionAlgorithm) -> Self {
66        match algo {
67            RustCompressionAlgorithm::Lz4 => CompressionAlgorithm::Lz4,
68            RustCompressionAlgorithm::Zstd => CompressionAlgorithm::Zstd,
69        }
70    }
71}
72
73impl From<CompressionAlgorithm> for RustCompressionAlgorithm {
74    fn from(algo: CompressionAlgorithm) -> Self {
75        match algo {
76            CompressionAlgorithm::Lz4 => RustCompressionAlgorithm::Lz4,
77            CompressionAlgorithm::Zstd => RustCompressionAlgorithm::Zstd,
78        }
79    }
80}
81
82#[pymethods]
83impl CompressionAlgorithm {
84    fn __repr__(&self) -> String {
85        format!("CompressionAlgorithm.{:?}", self)
86    }
87}
88
89/// Data type for tensor elements.
90#[pyclass]
91#[derive(Clone, Copy, Debug, PartialEq)]
92pub enum DType {
93    F32,
94    F16,
95    BF16,
96    I8,
97    I4,
98}
99
100impl From<RustDType> for DType {
101    fn from(dtype: RustDType) -> Self {
102        match dtype {
103            RustDType::F32 => DType::F32,
104            RustDType::F16 => DType::F16,
105            RustDType::BF16 => DType::BF16,
106            RustDType::I8 => DType::I8,
107            RustDType::I4 => DType::I4,
108        }
109    }
110}
111
112impl From<DType> for RustDType {
113    fn from(dtype: DType) -> Self {
114        match dtype {
115            DType::F32 => RustDType::F32,
116            DType::F16 => RustDType::F16,
117            DType::BF16 => RustDType::BF16,
118            DType::I8 => RustDType::I8,
119            DType::I4 => RustDType::I4,
120        }
121    }
122}
123
124#[pymethods]
125impl DType {
126    /// Bits per element for this dtype.
127    fn bits(&self) -> u32 {
128        match self {
129            DType::F32 => 32,
130            DType::F16 | DType::BF16 => 16,
131            DType::I8 => 8,
132            DType::I4 => 4,
133        }
134    }
135
136    /// Bytes per element (rounded up for sub-byte types).
137    fn bytes(&self) -> u32 {
138        (self.bits() + 7) / 8
139    }
140
141    fn __repr__(&self) -> String {
142        format!("DType.{:?}", self)
143    }
144}
145
146/// Quantization scheme for compressed tensors.
147#[pyclass]
148#[derive(Clone, Copy, Debug, PartialEq)]
149pub enum QuantizationScheme {
150    None,
151    GptqInt4,
152    AwqInt4,
153    SymmetricInt8,
154    AsymmetricInt8,
155}
156
157impl From<RustQuantizationScheme> for QuantizationScheme {
158    fn from(scheme: RustQuantizationScheme) -> Self {
159        match scheme {
160            RustQuantizationScheme::None => QuantizationScheme::None,
161            RustQuantizationScheme::GptqInt4 => QuantizationScheme::GptqInt4,
162            RustQuantizationScheme::AwqInt4 => QuantizationScheme::AwqInt4,
163            RustQuantizationScheme::SymmetricInt8 => QuantizationScheme::SymmetricInt8,
164            RustQuantizationScheme::AsymmetricInt8 => QuantizationScheme::AsymmetricInt8,
165        }
166    }
167}
168
169#[pymethods]
170impl QuantizationScheme {
171    fn __repr__(&self) -> String {
172        format!("QuantizationScheme.{:?}", self)
173    }
174}
175
176/// Holographic encoding type.
177#[pyclass]
178#[derive(Clone, Copy, Debug, PartialEq)]
179pub enum HolographicEncoding {
180    /// DCT-based spectral encoding (best for smooth weights)
181    Spectral,
182    /// Random projection hash (Johnson-Lindenstrauss)
183    RandomProjection,
184    /// Low-rank distributed factorization (SVD-based)
185    LowRankDistributed,
186}
187
188impl From<RustHolographicEncoding> for HolographicEncoding {
189    fn from(enc: RustHolographicEncoding) -> Self {
190        match enc {
191            RustHolographicEncoding::Spectral => HolographicEncoding::Spectral,
192            RustHolographicEncoding::RandomProjection => HolographicEncoding::RandomProjection,
193            RustHolographicEncoding::LowRankDistributed => HolographicEncoding::LowRankDistributed,
194        }
195    }
196}
197
198impl From<HolographicEncoding> for RustHolographicEncoding {
199    fn from(enc: HolographicEncoding) -> Self {
200        match enc {
201            HolographicEncoding::Spectral => RustHolographicEncoding::Spectral,
202            HolographicEncoding::RandomProjection => RustHolographicEncoding::RandomProjection,
203            HolographicEncoding::LowRankDistributed => RustHolographicEncoding::LowRankDistributed,
204        }
205    }
206}
207
208#[pymethods]
209impl HolographicEncoding {
210    fn __repr__(&self) -> String {
211        format!("HolographicEncoding.{:?}", self)
212    }
213}
214
215// ============================================================================
216// HCT Header
217// ============================================================================
218
219/// Header information for an HCT file.
220#[pyclass]
221#[derive(Clone)]
222pub struct HctHeader {
223    #[pyo3(get)]
224    pub algorithm: CompressionAlgorithm,
225    #[pyo3(get)]
226    pub dtype: DType,
227    #[pyo3(get)]
228    pub shape: Vec<u64>,
229    #[pyo3(get)]
230    pub original_size: u64,
231    #[pyo3(get)]
232    pub compressed_size: u64,
233    #[pyo3(get)]
234    pub block_size: u32,
235    #[pyo3(get)]
236    pub num_blocks: u32,
237}
238
/// Build the Python-facing header from a borrowed core-library header.
/// Cloning `shape` is the only allocation; the remaining fields are `Copy`.
impl From<&RustHctHeader> for HctHeader {
    fn from(header: &RustHctHeader) -> Self {
        HctHeader {
            algorithm: header.algorithm.into(),
            dtype: header.dtype.into(),
            shape: header.shape.clone(),
            original_size: header.original_size,
            compressed_size: header.compressed_size,
            block_size: header.block_size,
            num_blocks: header.num_blocks,
        }
    }
}
252
253#[pymethods]
254impl HctHeader {
255    /// Total number of elements in the tensor.
256    fn numel(&self) -> u64 {
257        self.shape.iter().product()
258    }
259
260    /// Compression ratio (original / compressed).
261    fn compression_ratio(&self) -> f64 {
262        if self.compressed_size == 0 {
263            0.0
264        } else {
265            self.original_size as f64 / self.compressed_size as f64
266        }
267    }
268
269    fn __repr__(&self) -> String {
270        format!(
271            "HctHeader(dtype={:?}, shape={:?}, ratio={:.2}x)",
272            self.dtype,
273            self.shape,
274            self.compression_ratio()
275        )
276    }
277}
278
279// ============================================================================
280// HCT Reader
281// ============================================================================
282
/// Reader for HCT (Haagenti Compressed Tensor) files.
///
/// Supports both V1 and V2 formats with optional checksum validation.
#[pyclass]
pub struct HctReader {
    // Buffered V2 reader over the opened file.
    // NOTE(review): only HctReaderV2 is held here; confirm it also accepts
    // V1 files as the docstring above claims.
    reader: HctReaderV2<BufReader<File>>,
    // Original path, kept for __repr__ / diagnostics.
    path: String,
}
291
#[pymethods]
impl HctReader {
    /// Open an HCT file for reading.
    ///
    /// Raises:
    ///     IOError: if the file cannot be opened or its header cannot be read.
    #[new]
    fn new(path: &str) -> PyResult<Self> {
        let file = File::open(path)
            .map_err(|e| PyIOError::new_err(format!("Failed to open {}: {}", path, e)))?;
        let buf_reader = BufReader::new(file);
        let reader = RustHctReaderV2::new(buf_reader)
            .map_err(|e| PyIOError::new_err(format!("Failed to read HCT header: {}", e)))?;
        Ok(HctReader {
            reader,
            path: path.to_string(),
        })
    }

    /// Get the file header (owned Python-facing copy).
    fn header(&self) -> HctHeader {
        HctHeader::from(self.reader.header())
    }

    /// Number of compressed blocks.
    fn num_blocks(&self) -> usize {
        self.reader.num_blocks()
    }

    /// Decompress all blocks and return as numpy array (float32).
    ///
    /// Raises:
    ///     IOError: if decompression or validation fails.
    fn decompress_all<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f32>>> {
        // Copy the Copy-able header fields up front so the later &mut self
        // calls on the reader don't conflict with a borrow of the header.
        let algorithm = self.reader.header().algorithm;
        let dtype = self.reader.header().dtype;

        // Each arm constructs its own concrete decompressor type, so the
        // bodies are duplicated rather than shared behind one abstraction.
        let data = match algorithm {
            RustCompressionAlgorithm::Lz4 => {
                let decompressor = Lz4Decompressor::new();
                self.reader
                    .decompress_all_validated(&decompressor)
                    .map_err(|e| PyIOError::new_err(format!("Decompression failed: {}", e)))?
            }
            RustCompressionAlgorithm::Zstd => {
                let decompressor = ZstdDecompressor::new();
                self.reader
                    .decompress_all_validated(&decompressor)
                    .map_err(|e| PyIOError::new_err(format!("Decompression failed: {}", e)))?
            }
        };

        // Convert raw bytes to f32 according to the stored dtype.
        // (bytes_to_f32 is a helper defined elsewhere in this file.)
        let floats = bytes_to_f32(&data, dtype)?;
        Ok(floats.into_pyarray_bound(py))
    }

    /// Decompress a single block by index; returns its raw (uncompressed)
    /// bytes as a numpy uint8 array.
    ///
    /// Raises:
    ///     IOError: if block decompression or validation fails.
    fn decompress_block<'py>(
        &mut self,
        py: Python<'py>,
        block_idx: usize,
    ) -> PyResult<Bound<'py, PyArray1<u8>>> {
        let algorithm = self.reader.header().algorithm;

        let data = match algorithm {
            RustCompressionAlgorithm::Lz4 => {
                let decompressor = Lz4Decompressor::new();
                self.reader
                    .decompress_block_validated(block_idx, &decompressor)
                    .map_err(|e| PyIOError::new_err(format!("Block decompression failed: {}", e)))?
            }
            RustCompressionAlgorithm::Zstd => {
                let decompressor = ZstdDecompressor::new();
                self.reader
                    .decompress_block_validated(block_idx, &decompressor)
                    .map_err(|e| PyIOError::new_err(format!("Block decompression failed: {}", e)))?
            }
        };

        Ok(data.into_pyarray_bound(py))
    }

    /// Validate all block checksums (V2 only).
    ///
    /// Raises:
    ///     ValueError: if any checksum does not match.
    fn validate_checksums(&mut self) -> PyResult<()> {
        self.reader
            .validate_checksums()
            .map_err(|e| PyValueError::new_err(format!("Checksum validation failed: {}", e)))
    }

    /// Python `repr()` with path and block count.
    fn __repr__(&self) -> String {
        format!("HctReader('{}', blocks={})", self.path, self.num_blocks())
    }
}
382
383// ============================================================================
384// HCT Writer
385// ============================================================================
386
/// Writer for HCT (Haagenti Compressed Tensor) files.
#[pyclass]
pub struct HctWriter {
    // Option so finish() can move the writer out exactly once; None after
    // finalization, and further calls raise ValueError.
    writer: Option<HctWriterV2<BufWriter<File>>>,
    // Original path, kept for __repr__.
    path: String,
    // Algorithm chosen at construction; selects the compressor in
    // compress_data().
    algorithm: CompressionAlgorithm,
}
394
#[pymethods]
impl HctWriter {
    /// Create a new HCT file for writing.
    ///
    /// Args mirror the HCT header: compression algorithm, element dtype,
    /// tensor shape, and an optional block-size override.
    ///
    /// Raises:
    ///     IOError: if the output file cannot be created.
    #[new]
    #[pyo3(signature = (path, algorithm, dtype, shape, block_size=None))]
    fn new(
        path: &str,
        algorithm: CompressionAlgorithm,
        dtype: DType,
        shape: Vec<u64>,
        block_size: Option<u32>,
    ) -> PyResult<Self> {
        let file = File::create(path)
            .map_err(|e| PyIOError::new_err(format!("Failed to create {}: {}", path, e)))?;
        let buf_writer = BufWriter::new(file);

        let mut writer = RustHctWriterV2::new(buf_writer, algorithm.into(), dtype.into(), shape);

        // Builder-style override, applied only when the caller supplied one.
        if let Some(bs) = block_size {
            writer = writer.with_block_size(bs);
        }

        Ok(HctWriter {
            // Wrapped in Option so finish() can consume it exactly once.
            writer: Some(writer),
            path: path.to_string(),
            algorithm,
        })
    }

    /// Compress and write data from a numpy array.
    ///
    /// Raises:
    ///     ValueError: if finish() was already called.
    ///     IOError: if compression or writing fails.
    fn compress_data(&mut self, data: PyReadonlyArray1<f32>) -> PyResult<()> {
        let writer = self
            .writer
            .as_mut()
            .ok_or_else(|| PyValueError::new_err("Writer already finalized"))?;

        // Serialize the f32 slice to little-endian bytes before compressing.
        let slice = data.as_slice()?;
        let bytes: Vec<u8> = slice.iter().flat_map(|f| f.to_le_bytes()).collect();

        match self.algorithm {
            CompressionAlgorithm::Lz4 => {
                let compressor = Lz4Compressor::new();
                writer
                    .compress_data(&bytes, &compressor)
                    .map_err(|e| PyIOError::new_err(format!("Compression failed: {}", e)))?;
            }
            CompressionAlgorithm::Zstd => {
                // ZstdCompressor::new() uses the library's default settings.
                // NOTE(review): a previous comment claimed "Level 3" here —
                // confirm against haagenti_zstd before documenting a level.
                let compressor = ZstdCompressor::new();
                writer
                    .compress_data(&bytes, &compressor)
                    .map_err(|e| PyIOError::new_err(format!("Compression failed: {}", e)))?;
            }
        }

        Ok(())
    }

    /// Finalize the file and flush to disk.
    ///
    /// Consumes the inner writer; subsequent compress_data()/finish() calls
    /// raise ValueError.
    fn finish(&mut self) -> PyResult<()> {
        let writer = self
            .writer
            .take()
            .ok_or_else(|| PyValueError::new_err("Writer already finalized"))?;

        writer
            .finish()
            .map_err(|e| PyIOError::new_err(format!("Failed to finalize: {}", e)))
    }

    /// Python `repr()` with the output path.
    fn __repr__(&self) -> String {
        format!("HctWriter('{}')", self.path)
    }
}
468
469// ============================================================================
470// HoloTensor Encoder (Phase 4 - Progressive Loading)
471// ============================================================================
472
/// Encoder for HoloTensor (progressive tensor loading).
///
/// Supports three encoding schemes:
/// - Spectral: DCT-based, best for smooth weights (attention, embeddings)
/// - RandomProjection: Johnson-Lindenstrauss, good for dense layers
/// - LowRankDistributed: SVD-based, best for low-rank matrices
///
/// Example:
///     encoder = HoloTensorEncoder(HolographicEncoding.Spectral, n_fragments=8)
///     header, fragments = encoder.encode_2d(weights, 4096, 4096)
#[pyclass]
pub struct HoloTensorEncoder {
    // Configured core-library encoder doing the actual work.
    encoder: haagenti::HoloTensorEncoder,
    // Settings mirrored here so the #[getter]s don't query the inner encoder.
    encoding: HolographicEncoding,
    n_fragments: u16,
}
489
#[pymethods]
impl HoloTensorEncoder {
    /// Create a new HoloTensor encoder.
    ///
    /// Args:
    ///     encoding: Holographic encoding scheme
    ///     n_fragments: Number of fragments to create (default 8)
    ///     seed: Random seed for deterministic encoding
    ///     essential_ratio: Ratio of essential data in first fragment (0.01-0.5)
    ///     max_rank: Maximum rank for LRDF encoding
    #[new]
    #[pyo3(signature = (encoding, n_fragments=None, seed=None, essential_ratio=None, max_rank=None))]
    fn new(
        encoding: HolographicEncoding,
        n_fragments: Option<u16>,
        seed: Option<u64>,
        essential_ratio: Option<f32>,
        max_rank: Option<usize>,
    ) -> Self {
        let n_frags = n_fragments.unwrap_or(8);
        let mut encoder = haagenti::HoloTensorEncoder::new(encoding.into()).with_fragments(n_frags);

        // Optional builder knobs, applied only when supplied by the caller.
        if let Some(s) = seed {
            encoder = encoder.with_seed(s);
        }
        if let Some(r) = essential_ratio {
            encoder = encoder.with_essential_ratio(r);
        }
        if let Some(r) = max_rank {
            encoder = encoder.with_max_rank(r);
        }

        HoloTensorEncoder {
            encoder,
            // Mirrored for the #[getter]s below.
            encoding,
            n_fragments: n_frags,
        }
    }

    /// Encode a 2D tensor (matrix) into holographic fragments.
    ///
    /// Args:
    ///     data: Flattened tensor data (float32); length must equal rows*cols
    ///     rows: Number of rows
    ///     cols: Number of columns
    ///
    /// Returns:
    ///     Tuple of (HoloTensorHeader, list of HoloFragment)
    ///
    /// Raises:
    ///     ValueError: if the length check fails or encoding fails.
    fn encode_2d(
        &self,
        data: PyReadonlyArray1<f32>,
        rows: usize,
        cols: usize,
    ) -> PyResult<(HoloTensorHeaderPy, Vec<HoloFragmentPy>)> {
        let slice = data.as_slice()?;

        // Validate the flattened length before handing off to the encoder.
        if slice.len() != rows * cols {
            return Err(PyValueError::new_err(format!(
                "Data length {} doesn't match {}x{}={}",
                slice.len(),
                rows,
                cols,
                rows * cols
            )));
        }

        let (header, fragments) = self
            .encoder
            .encode_2d(slice, rows, cols)
            .map_err(|e| PyValueError::new_err(format!("Encoding failed: {}", e)))?;

        // Convert to Python types
        let header_py = HoloTensorHeaderPy::from(&header);
        let fragments_py: Vec<HoloFragmentPy> =
            fragments.into_iter().map(HoloFragmentPy::from).collect();

        Ok((header_py, fragments_py))
    }

    /// Encode a 1D tensor (vector) into holographic fragments.
    ///
    /// Returns:
    ///     Tuple of (HoloTensorHeader, list of HoloFragment)
    ///
    /// Raises:
    ///     ValueError: if encoding fails.
    fn encode_1d(
        &self,
        data: PyReadonlyArray1<f32>,
    ) -> PyResult<(HoloTensorHeaderPy, Vec<HoloFragmentPy>)> {
        let slice = data.as_slice()?;

        let (header, fragments) = self
            .encoder
            .encode_1d(slice)
            .map_err(|e| PyValueError::new_err(format!("Encoding failed: {}", e)))?;

        let header_py = HoloTensorHeaderPy::from(&header);
        let fragments_py: Vec<HoloFragmentPy> =
            fragments.into_iter().map(HoloFragmentPy::from).collect();

        Ok((header_py, fragments_py))
    }

    /// Get the encoding scheme.
    #[getter]
    fn encoding(&self) -> HolographicEncoding {
        self.encoding
    }

    /// Get the number of fragments.
    #[getter]
    fn n_fragments(&self) -> u16 {
        self.n_fragments
    }

    /// Python `repr()` with encoding and fragment count.
    fn __repr__(&self) -> String {
        format!(
            "HoloTensorEncoder(encoding={:?}, n_fragments={})",
            self.encoding, self.n_fragments
        )
    }
}
607
608// ============================================================================
609// HoloTensor Decoder (Phase 4 - Progressive Loading)
610// ============================================================================
611
/// Decoder for HoloTensor with progressive reconstruction.
///
/// Allows loading fragments incrementally and reconstructing
/// the tensor at any quality level. Quality improves as more
/// fragments are added.
///
/// Example:
/// ```python
/// decoder = HoloTensorDecoder(header)
/// decoder.add_fragment(fragments[0])  # partial quality
/// decoder.add_fragment(fragments[1])  # higher quality
/// weights = decoder.reconstruct()
/// ```
#[pyclass]
pub struct HoloTensorDecoder {
    // Core-library decoder accumulating fragments.
    decoder: haagenti::HoloTensorDecoder,
    // Python-facing header kept so the header/total_fragments getters don't
    // need to re-read it from the decoder.
    header: HoloTensorHeaderPy,
}
630
#[pymethods]
impl HoloTensorDecoder {
    /// Create a decoder from a header.
    ///
    /// Raises:
    ///     ValueError: if the header cannot be converted (see to_rust_header).
    #[new]
    fn new(header: HoloTensorHeaderPy) -> PyResult<Self> {
        let rust_header = header.to_rust_header()?;
        Ok(HoloTensorDecoder {
            decoder: haagenti::HoloTensorDecoder::new(rust_header),
            header,
        })
    }

    /// Add a fragment to the reconstruction.
    ///
    /// Returns the new quality level (0.0-1.0).
    ///
    /// Raises:
    ///     ValueError: if the fragment is rejected by the decoder.
    fn add_fragment(&mut self, fragment: &HoloFragmentPy) -> PyResult<f32> {
        // Clones the fragment's payload; the Python object stays usable.
        let rust_fragment = fragment.to_rust_fragment();
        self.decoder
            .add_fragment(rust_fragment)
            .map_err(|e| PyValueError::new_err(format!("Failed to add fragment: {}", e)))
    }

    /// Current reconstruction quality (0.0-1.0).
    ///
    /// Quality represents how close the reconstruction is to the original.
    /// 1.0 means perfect reconstruction (all fragments loaded).
    #[getter]
    fn quality(&self) -> f32 {
        self.decoder.quality()
    }

    /// Number of fragments loaded so far.
    #[getter]
    fn fragments_loaded(&self) -> u16 {
        self.decoder.fragments_loaded()
    }

    /// Total number of fragments (from the header, not the decoder).
    #[getter]
    fn total_fragments(&self) -> u16 {
        self.header.total_fragments
    }

    /// Check if minimum fragments for reconstruction are loaded.
    fn can_reconstruct(&self) -> bool {
        self.decoder.can_reconstruct()
    }

    /// Reconstruct the tensor from loaded fragments.
    ///
    /// Returns a numpy array with the reconstructed weights.
    /// Quality depends on how many fragments have been loaded.
    ///
    /// Raises:
    ///     ValueError: if reconstruction fails (e.g. too few fragments).
    fn reconstruct<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f32>>> {
        let data = self
            .decoder
            .reconstruct()
            .map_err(|e| PyValueError::new_err(format!("Reconstruction failed: {}", e)))?;
        Ok(data.into_pyarray_bound(py))
    }

    /// Get the header (cloned).
    #[getter]
    fn header(&self) -> HoloTensorHeaderPy {
        self.header.clone()
    }

    /// Python `repr()` with quality and fragment progress.
    fn __repr__(&self) -> String {
        format!(
            "HoloTensorDecoder(quality={:.1}%, fragments={}/{})",
            self.quality() * 100.0,
            self.fragments_loaded(),
            self.total_fragments()
        )
    }
}
706
707// ============================================================================
708// HoloTensor Header (Python wrapper)
709// ============================================================================
710
/// Header for a HoloTensor file.
///
/// Python-facing subset of `haagenti::HoloTensorHeader`; fields not stored
/// here are filled with defaults by `to_rust_header`.
#[pyclass]
#[derive(Clone)]
pub struct HoloTensorHeaderPy {
    /// Holographic encoding scheme.
    #[pyo3(get)]
    pub encoding: HolographicEncoding,
    /// Total number of fragments the tensor was split into.
    #[pyo3(get)]
    pub total_fragments: u16,
    /// Minimum fragments needed before reconstruction is possible.
    #[pyo3(get)]
    pub min_fragments: u16,
    /// Tensor dimensions.
    #[pyo3(get)]
    pub shape: Vec<u64>,
    /// Original (uncompressed) size in bytes.
    #[pyo3(get)]
    pub original_size: u64,
    /// Random seed used for deterministic encoding.
    #[pyo3(get)]
    pub seed: u64,
}
728
/// Build the Python-facing header from a borrowed core-library header.
/// Only the fields exposed on the Python side are copied; compression, flags,
/// dtype, quality curve and quantization are dropped here.
impl From<&haagenti::HoloTensorHeader> for HoloTensorHeaderPy {
    fn from(h: &haagenti::HoloTensorHeader) -> Self {
        HoloTensorHeaderPy {
            encoding: h.encoding.into(),
            total_fragments: h.total_fragments,
            min_fragments: h.min_fragments,
            shape: h.shape.clone(),
            original_size: h.original_size,
            seed: h.seed,
        }
    }
}
741
impl HoloTensorHeaderPy {
    /// Rebuild a core-library header from the Python-facing fields.
    ///
    /// NOTE(review): fields not stored on this wrapper are filled with fixed
    /// defaults (compression=Lz4, flags=0, dtype=F32, default quality curve,
    /// no quantization). A header whose original values differed is silently
    /// rewritten on the round-trip — confirm this is acceptable for decoding.
    fn to_rust_header(&self) -> PyResult<haagenti::HoloTensorHeader> {
        Ok(haagenti::HoloTensorHeader {
            encoding: self.encoding.into(),
            compression: RustCompressionAlgorithm::Lz4,
            flags: 0,
            total_fragments: self.total_fragments,
            min_fragments: self.min_fragments,
            original_size: self.original_size,
            seed: self.seed,
            dtype: RustDType::F32,
            shape: self.shape.clone(),
            quality_curve: haagenti::QualityCurve::default(),
            quantization: None,
        })
    }
}
759
760#[pymethods]
761impl HoloTensorHeaderPy {
762    fn __repr__(&self) -> String {
763        format!(
764            "HoloTensorHeader(encoding={:?}, fragments={}, shape={:?})",
765            self.encoding, self.total_fragments, self.shape
766        )
767    }
768}
769
770// ============================================================================
771// HoloTensor Fragment (Python wrapper)
772// ============================================================================
773
/// A fragment of a HoloTensor.
///
/// Each fragment contains information about the whole tensor.
/// Any subset of fragments can reconstruct an approximation.
#[pyclass]
#[derive(Clone)]
pub struct HoloFragmentPy {
    /// Fragment index within the tensor's fragment set.
    #[pyo3(get)]
    pub index: u16,
    /// Fragment flags (semantics defined by the core library).
    #[pyo3(get)]
    pub flags: u16,
    /// Payload checksum.
    #[pyo3(get)]
    pub checksum: u64,
    // Raw payload; intentionally private — exposed to Python via the data()
    // method and size getter below.
    data: Vec<u8>,
}
789
/// Take ownership of a core-library fragment (moves the payload, no copy).
impl From<haagenti::HoloFragment> for HoloFragmentPy {
    fn from(f: haagenti::HoloFragment) -> Self {
        HoloFragmentPy {
            index: f.index,
            flags: f.flags,
            checksum: f.checksum,
            data: f.data,
        }
    }
}
800
impl HoloFragmentPy {
    /// Rebuild a core-library fragment; clones the payload so the Python
    /// object remains usable afterwards.
    fn to_rust_fragment(&self) -> haagenti::HoloFragment {
        haagenti::HoloFragment {
            index: self.index,
            flags: self.flags,
            checksum: self.checksum,
            data: self.data.clone(),
        }
    }
}
811
812#[pymethods]
813impl HoloFragmentPy {
814    /// Get fragment data as bytes.
815    fn data<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<u8>> {
816        self.data.clone().into_pyarray_bound(py)
817    }
818
819    /// Size of fragment data in bytes.
820    #[getter]
821    fn size(&self) -> usize {
822        self.data.len()
823    }
824
825    fn __repr__(&self) -> String {
826        format!(
827            "HoloFragment(index={}, size={})",
828            self.index,
829            self.data.len()
830        )
831    }
832}
833
834// ============================================================================
835// Utility Functions
836// ============================================================================
837
838/// Convert safetensors file to HCT format.
839#[pyfunction]
840#[pyo3(signature = (input_path, output_path, algorithm=CompressionAlgorithm::Lz4))]
841fn convert_safetensors_to_hct(
842    input_path: &str,
843    output_path: &str,
844    algorithm: CompressionAlgorithm,
845) -> PyResult<(u64, u64, f64)> {
846    // Read safetensors file
847    let data = std::fs::read(input_path)
848        .map_err(|e| PyIOError::new_err(format!("Failed to read {}: {}", input_path, e)))?;
849
850    // Parse safetensors header (JSON + tensors)
851    // For now, we just compress the raw bytes - real implementation would parse properly
852    let original_size = data.len() as u64;
853
854    // Create HCT writer
855    let file = File::create(output_path)
856        .map_err(|e| PyIOError::new_err(format!("Failed to create {}: {}", output_path, e)))?;
857    let buf_writer = BufWriter::new(file);
858
859    let mut writer = RustHctWriterV2::new(
860        buf_writer,
861        algorithm.into(),
862        RustDType::F32,          // Default to F32 for safetensors
863        vec![data.len() as u64], // Treat as 1D for raw conversion
864    );
865
866    match algorithm {
867        CompressionAlgorithm::Lz4 => {
868            let compressor = Lz4Compressor::new();
869            writer
870                .compress_data(&data, &compressor)
871                .map_err(|e| PyIOError::new_err(format!("Compression failed: {}", e)))?;
872        }
873        CompressionAlgorithm::Zstd => {
874            let compressor = ZstdCompressor::new();
875            writer
876                .compress_data(&data, &compressor)
877                .map_err(|e| PyIOError::new_err(format!("Compression failed: {}", e)))?;
878        }
879    }
880
881    writer
882        .finish()
883        .map_err(|e| PyIOError::new_err(format!("Failed to finalize: {}", e)))?;
884
885    // Get compressed size
886    let compressed_size = std::fs::metadata(output_path)
887        .map_err(|e| PyIOError::new_err(format!("Failed to stat {}: {}", output_path, e)))?
888        .len();
889
890    let ratio = original_size as f64 / compressed_size as f64;
891
892    Ok((original_size, compressed_size, ratio))
893}
894
895/// Get version information.
896#[pyfunction]
897fn version() -> String {
898    env!("CARGO_PKG_VERSION").to_string()
899}
900
901// ============================================================================
902// Top-Level Compression Functions (C.5)
903// ============================================================================
904
/// Compress data using the specified algorithm.
///
/// Args:
///     data: Bytes to compress
///     algorithm: Compression algorithm ("zstd" or "lz4", case-insensitive)
///     level: Compression level ("fast", "default", or "best").
///         NOTE: only honored by "zstd"; the "lz4" path ignores it.
///     dictionary: Optional ZstdDict for dictionary compression.
///         NOTE: currently accepted but unused (future work).
///
/// Returns:
///     Compressed bytes
///
/// Raises:
///     ValueError: on unknown level/algorithm or compression failure.
#[pyfunction]
#[pyo3(signature = (data, algorithm="zstd", level="default", dictionary=None))]
fn compress(
    data: &[u8],
    algorithm: &str,
    level: &str,
    dictionary: Option<&ZstdDict>,
) -> PyResult<Vec<u8>> {
    let _ = dictionary; // Dictionary support is future work

    // Map the textual level onto the core enum; reject unknown strings early,
    // even for lz4 where the level is ultimately unused.
    let compression_level = match level {
        "fast" => haagenti_core::CompressionLevel::Fast,
        "default" => haagenti_core::CompressionLevel::Default,
        "best" => haagenti_core::CompressionLevel::Best,
        _ => return Err(PyValueError::new_err(format!("Invalid level: {}", level))),
    };

    // Case-insensitive algorithm dispatch.
    match algorithm.to_lowercase().as_str() {
        "zstd" => {
            let compressor = ZstdCompressor::with_level(compression_level);
            compressor
                .compress(data)
                .map_err(|e| PyValueError::new_err(format!("Zstd compression failed: {}", e)))
        }
        "lz4" => {
            // Lz4Compressor::new() takes no level; `compression_level` is
            // not applied on this path.
            let compressor = Lz4Compressor::new();
            compressor
                .compress(data)
                .map_err(|e| PyValueError::new_err(format!("LZ4 compression failed: {}", e)))
        }
        _ => Err(PyValueError::new_err(format!(
            "Invalid algorithm: {}. Use 'zstd' or 'lz4'",
            algorithm
        ))),
    }
}
951
952/// Decompress data using the specified algorithm.
953///
954/// Args:
955///     data: Compressed bytes
956///     algorithm: Compression algorithm ("zstd" or "lz4")
957///
958/// Returns:
959///     Decompressed bytes
960#[pyfunction]
961#[pyo3(signature = (data, algorithm="zstd"))]
962fn decompress(data: &[u8], algorithm: &str) -> PyResult<Vec<u8>> {
963    match algorithm.to_lowercase().as_str() {
964        "zstd" => {
965            let decompressor = ZstdDecompressor::new();
966            decompressor.decompress(data).map_err(|e| {
967                DecompressionError::new_err(format!("Zstd decompression failed: {}", e))
968            })
969        }
970        "lz4" => {
971            let decompressor = Lz4Decompressor::new();
972            decompressor.decompress(data).map_err(|e| {
973                DecompressionError::new_err(format!("LZ4 decompression failed: {}", e))
974            })
975        }
976        _ => Err(PyValueError::new_err(format!(
977            "Invalid algorithm: {}. Use 'zstd' or 'lz4'",
978            algorithm
979        ))),
980    }
981}
982
983// ============================================================================
984// Zstd Dictionary Support (C.5)
985// ============================================================================
986
/// A trained Zstd dictionary for improved compression.
#[pyclass]
#[derive(Clone)]
pub struct ZstdDict {
    /// Dictionary ID as reported by the trainer (`ZstdDictionary::id`).
    id: u32,
    /// Serialized dictionary bytes (`ZstdDictionary::serialize` output).
    data: Vec<u8>,
}
994
995#[pymethods]
996impl ZstdDict {
997    /// Train a dictionary from sample data.
998    ///
999    /// Args:
1000    ///     samples: List of bytes samples to train on
1001    ///     max_size: Maximum dictionary size in bytes
1002    ///
1003    /// Returns:
1004    ///     Trained ZstdDict
1005    #[staticmethod]
1006    #[pyo3(signature = (samples, max_size=8192))]
1007    fn train(samples: Vec<Vec<u8>>, max_size: usize) -> PyResult<Self> {
1008        use haagenti_zstd::ZstdDictionary;
1009
1010        if samples.len() < 5 {
1011            return Err(PyValueError::new_err(
1012                "Need at least 5 samples for dictionary training",
1013            ));
1014        }
1015
1016        let sample_refs: Vec<&[u8]> = samples.iter().map(|s| s.as_slice()).collect();
1017        let dict = ZstdDictionary::train(&sample_refs, max_size)
1018            .map_err(|e| PyValueError::new_err(format!("Dictionary training failed: {}", e)))?;
1019
1020        Ok(ZstdDict {
1021            id: dict.id(),
1022            data: dict.serialize(),
1023        })
1024    }
1025
1026    /// Dictionary ID.
1027    #[getter]
1028    fn id(&self) -> u32 {
1029        self.id
1030    }
1031
1032    /// Get dictionary as bytes.
1033    fn as_bytes(&self) -> Vec<u8> {
1034        self.data.clone()
1035    }
1036
1037    fn __repr__(&self) -> String {
1038        format!("ZstdDict(id={}, size={})", self.id, self.data.len())
1039    }
1040}
1041
1042// ============================================================================
1043// Streaming Encoder/Decoder (C.5)
1044// ============================================================================
1045
1046/// Streaming encoder for incremental compression.
1047#[pyclass]
1048pub struct StreamingEncoder {
1049    algorithm: String,
1050    buffer: Vec<u8>,
1051}
1052
1053#[pymethods]
1054impl StreamingEncoder {
1055    /// Create a new streaming encoder.
1056    #[new]
1057    fn new(algorithm: &str) -> PyResult<Self> {
1058        match algorithm.to_lowercase().as_str() {
1059            "zstd" | "lz4" => Ok(StreamingEncoder {
1060                algorithm: algorithm.to_lowercase(),
1061                buffer: Vec::new(),
1062            }),
1063            _ => Err(PyValueError::new_err(format!(
1064                "Invalid algorithm: {}",
1065                algorithm
1066            ))),
1067        }
1068    }
1069
1070    /// Write data to the encoder.
1071    fn write(&mut self, data: &[u8]) -> PyResult<()> {
1072        self.buffer.extend_from_slice(data);
1073        Ok(())
1074    }
1075
1076    /// Finish encoding and return compressed data.
1077    fn finish(&mut self) -> PyResult<Vec<u8>> {
1078        let result = match self.algorithm.as_str() {
1079            "zstd" => {
1080                let compressor = ZstdCompressor::new();
1081                compressor
1082                    .compress(&self.buffer)
1083                    .map_err(|e| PyValueError::new_err(format!("Compression failed: {}", e)))?
1084            }
1085            "lz4" => {
1086                let compressor = Lz4Compressor::new();
1087                compressor
1088                    .compress(&self.buffer)
1089                    .map_err(|e| PyValueError::new_err(format!("Compression failed: {}", e)))?
1090            }
1091            _ => return Err(PyValueError::new_err("Invalid algorithm")),
1092        };
1093        self.buffer.clear();
1094        Ok(result)
1095    }
1096
1097    fn __enter__(slf: PyRef<Self>) -> PyRef<Self> {
1098        slf
1099    }
1100
1101    fn __exit__(
1102        &mut self,
1103        _exc_type: Option<PyObject>,
1104        _exc_val: Option<PyObject>,
1105        _exc_tb: Option<PyObject>,
1106    ) -> bool {
1107        false
1108    }
1109}
1110
1111/// Streaming decoder for incremental decompression.
1112#[pyclass]
1113pub struct StreamingDecoder {
1114    algorithm: String,
1115    buffer: Vec<u8>,
1116}
1117
1118#[pymethods]
1119impl StreamingDecoder {
1120    /// Create a new streaming decoder.
1121    #[new]
1122    fn new(algorithm: &str) -> PyResult<Self> {
1123        match algorithm.to_lowercase().as_str() {
1124            "zstd" | "lz4" => Ok(StreamingDecoder {
1125                algorithm: algorithm.to_lowercase(),
1126                buffer: Vec::new(),
1127            }),
1128            _ => Err(PyValueError::new_err(format!(
1129                "Invalid algorithm: {}",
1130                algorithm
1131            ))),
1132        }
1133    }
1134
1135    /// Write compressed data to the decoder.
1136    ///
1137    /// Returns any decompressed data available (may be empty).
1138    fn write(&mut self, data: &[u8]) -> PyResult<Option<Vec<u8>>> {
1139        self.buffer.extend_from_slice(data);
1140        Ok(None) // Streaming decompression returns data on finish
1141    }
1142
1143    /// Finish decoding and return remaining data.
1144    fn finish(&mut self) -> PyResult<Vec<u8>> {
1145        let result = match self.algorithm.as_str() {
1146            "zstd" => {
1147                let decompressor = ZstdDecompressor::new();
1148                decompressor.decompress(&self.buffer).map_err(|e| {
1149                    DecompressionError::new_err(format!("Decompression failed: {}", e))
1150                })?
1151            }
1152            "lz4" => {
1153                let decompressor = Lz4Decompressor::new();
1154                decompressor.decompress(&self.buffer).map_err(|e| {
1155                    DecompressionError::new_err(format!("Decompression failed: {}", e))
1156                })?
1157            }
1158            _ => return Err(PyValueError::new_err("Invalid algorithm")),
1159        };
1160        self.buffer.clear();
1161        Ok(result)
1162    }
1163
1164    fn __enter__(slf: PyRef<Self>) -> PyRef<Self> {
1165        slf
1166    }
1167
1168    fn __exit__(
1169        &mut self,
1170        _exc_type: Option<PyObject>,
1171        _exc_val: Option<PyObject>,
1172        _exc_tb: Option<PyObject>,
1173    ) -> bool {
1174        false
1175    }
1176}
1177
1178// ============================================================================
1179// Custom Exception Types (C.5)
1180// ============================================================================
1181
// Raised by `decompress` and `StreamingDecoder::finish` on backend failure;
// exported to Python as `haagenti.DecompressionError` (subclass of Exception).
pyo3::create_exception!(haagenti, DecompressionError, pyo3::exceptions::PyException);
1183
1184// ============================================================================
1185// Helper Functions
1186// ============================================================================
1187
1188fn bytes_to_f32(data: &[u8], dtype: RustDType) -> PyResult<Vec<f32>> {
1189    match dtype {
1190        RustDType::F32 => {
1191            if data.len() % 4 != 0 {
1192                return Err(PyValueError::new_err("Invalid F32 data length"));
1193            }
1194            Ok(data
1195                .chunks_exact(4)
1196                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
1197                .collect())
1198        }
1199        RustDType::F16 => {
1200            if data.len() % 2 != 0 {
1201                return Err(PyValueError::new_err("Invalid F16 data length"));
1202            }
1203            Ok(data
1204                .chunks_exact(2)
1205                .map(|b| {
1206                    let bits = u16::from_le_bytes([b[0], b[1]]);
1207                    half::f16::from_bits(bits).to_f32()
1208                })
1209                .collect())
1210        }
1211        RustDType::BF16 => {
1212            if data.len() % 2 != 0 {
1213                return Err(PyValueError::new_err("Invalid BF16 data length"));
1214            }
1215            Ok(data
1216                .chunks_exact(2)
1217                .map(|b| {
1218                    let bits = u16::from_le_bytes([b[0], b[1]]);
1219                    half::bf16::from_bits(bits).to_f32()
1220                })
1221                .collect())
1222        }
1223        RustDType::I8 => Ok(data.iter().map(|&b| b as i8 as f32).collect()),
1224        RustDType::I4 => {
1225            // Unpack 4-bit values
1226            Ok(data
1227                .iter()
1228                .flat_map(|&b| {
1229                    let lo = (b & 0x0F) as i8;
1230                    let hi = ((b >> 4) & 0x0F) as i8;
1231                    // Sign-extend 4-bit to 8-bit
1232                    let lo = if lo & 0x08 != 0 {
1233                        lo | 0xF0u8 as i8
1234                    } else {
1235                        lo
1236                    };
1237                    let hi = if hi & 0x08 != 0 {
1238                        hi | 0xF0u8 as i8
1239                    } else {
1240                        hi
1241                    };
1242                    vec![lo as f32, hi as f32]
1243                })
1244                .collect())
1245        }
1246    }
1247}
1248
1249// ============================================================================
1250// Python Module
1251// ============================================================================
1252
1253/// Haagenti Python bindings for tensor compression.
1254#[pymodule]
1255fn _haagenti_python(m: &Bound<'_, PyModule>) -> PyResult<()> {
1256    // Enums
1257    m.add_class::<CompressionAlgorithm>()?;
1258    m.add_class::<DType>()?;
1259    m.add_class::<QuantizationScheme>()?;
1260    m.add_class::<HolographicEncoding>()?;
1261
1262    // HCT classes
1263    m.add_class::<HctHeader>()?;
1264    m.add_class::<HctReader>()?;
1265    m.add_class::<HctWriter>()?;
1266
1267    // HoloTensor classes
1268    m.add_class::<HoloTensorEncoder>()?;
1269    m.add_class::<HoloTensorDecoder>()?;
1270    m.add_class::<HoloTensorHeaderPy>()?;
1271    m.add_class::<HoloFragmentPy>()?;
1272
1273    // C.5: Compression classes
1274    m.add_class::<ZstdDict>()?;
1275    m.add_class::<StreamingEncoder>()?;
1276    m.add_class::<StreamingDecoder>()?;
1277
1278    // Functions
1279    m.add_function(wrap_pyfunction!(convert_safetensors_to_hct, m)?)?;
1280    m.add_function(wrap_pyfunction!(version, m)?)?;
1281    m.add_function(wrap_pyfunction!(compress, m)?)?;
1282    m.add_function(wrap_pyfunction!(decompress, m)?)?;
1283
1284    // C.5: Custom exceptions
1285    m.add(
1286        "DecompressionError",
1287        m.py().get_type_bound::<DecompressionError>(),
1288    )?;
1289
1290    // Create streaming submodule
1291    let streaming = PyModule::new_bound(m.py(), "streaming")?;
1292    streaming.add_class::<StreamingEncoder>()?;
1293    streaming.add_class::<StreamingDecoder>()?;
1294    m.add_submodule(&streaming)?;
1295
1296    Ok(())
1297}