1#![allow(deprecated)]
3#![allow(clippy::manual_div_ceil)]
5#![allow(unused_imports)]
7#![allow(clippy::useless_conversion)]
9#![allow(unexpected_cfgs)]
11
12use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1, ToPyArray};
33use pyo3::exceptions::{PyIOError, PyValueError};
34use pyo3::prelude::*;
35use std::fs::File;
36use std::io::{BufReader, BufWriter};
37
38use haagenti::{
40 CompressionAlgorithm as RustCompressionAlgorithm, Compressor, DType as RustDType, Decompressor,
41 HctHeader as RustHctHeader, HctReaderV2, HctWriterV2,
42 HolographicEncoding as RustHolographicEncoding, QuantizationScheme as RustQuantizationScheme,
43};
44
45type RustHctReaderV2<R> = HctReaderV2<R>;
47type RustHctWriterV2<W> = HctWriterV2<W>;
48
49use haagenti_lz4::{Lz4Compressor, Lz4Decompressor};
50use haagenti_zstd::{ZstdCompressor, ZstdDecompressor};
51
52#[pyclass]
58#[derive(Clone, Copy, Debug, PartialEq)]
59pub enum CompressionAlgorithm {
60 Lz4,
61 Zstd,
62}
63
64impl From<RustCompressionAlgorithm> for CompressionAlgorithm {
65 fn from(algo: RustCompressionAlgorithm) -> Self {
66 match algo {
67 RustCompressionAlgorithm::Lz4 => CompressionAlgorithm::Lz4,
68 RustCompressionAlgorithm::Zstd => CompressionAlgorithm::Zstd,
69 }
70 }
71}
72
73impl From<CompressionAlgorithm> for RustCompressionAlgorithm {
74 fn from(algo: CompressionAlgorithm) -> Self {
75 match algo {
76 CompressionAlgorithm::Lz4 => RustCompressionAlgorithm::Lz4,
77 CompressionAlgorithm::Zstd => RustCompressionAlgorithm::Zstd,
78 }
79 }
80}
81
82#[pymethods]
83impl CompressionAlgorithm {
84 fn __repr__(&self) -> String {
85 format!("CompressionAlgorithm.{:?}", self)
86 }
87}
88
89#[pyclass]
91#[derive(Clone, Copy, Debug, PartialEq)]
92pub enum DType {
93 F32,
94 F16,
95 BF16,
96 I8,
97 I4,
98}
99
100impl From<RustDType> for DType {
101 fn from(dtype: RustDType) -> Self {
102 match dtype {
103 RustDType::F32 => DType::F32,
104 RustDType::F16 => DType::F16,
105 RustDType::BF16 => DType::BF16,
106 RustDType::I8 => DType::I8,
107 RustDType::I4 => DType::I4,
108 }
109 }
110}
111
112impl From<DType> for RustDType {
113 fn from(dtype: DType) -> Self {
114 match dtype {
115 DType::F32 => RustDType::F32,
116 DType::F16 => RustDType::F16,
117 DType::BF16 => RustDType::BF16,
118 DType::I8 => RustDType::I8,
119 DType::I4 => RustDType::I4,
120 }
121 }
122}
123
124#[pymethods]
125impl DType {
126 fn bits(&self) -> u32 {
128 match self {
129 DType::F32 => 32,
130 DType::F16 | DType::BF16 => 16,
131 DType::I8 => 8,
132 DType::I4 => 4,
133 }
134 }
135
136 fn bytes(&self) -> u32 {
138 (self.bits() + 7) / 8
139 }
140
141 fn __repr__(&self) -> String {
142 format!("DType.{:?}", self)
143 }
144}
145
146#[pyclass]
148#[derive(Clone, Copy, Debug, PartialEq)]
149pub enum QuantizationScheme {
150 None,
151 GptqInt4,
152 AwqInt4,
153 SymmetricInt8,
154 AsymmetricInt8,
155}
156
157impl From<RustQuantizationScheme> for QuantizationScheme {
158 fn from(scheme: RustQuantizationScheme) -> Self {
159 match scheme {
160 RustQuantizationScheme::None => QuantizationScheme::None,
161 RustQuantizationScheme::GptqInt4 => QuantizationScheme::GptqInt4,
162 RustQuantizationScheme::AwqInt4 => QuantizationScheme::AwqInt4,
163 RustQuantizationScheme::SymmetricInt8 => QuantizationScheme::SymmetricInt8,
164 RustQuantizationScheme::AsymmetricInt8 => QuantizationScheme::AsymmetricInt8,
165 }
166 }
167}
168
169#[pymethods]
170impl QuantizationScheme {
171 fn __repr__(&self) -> String {
172 format!("QuantizationScheme.{:?}", self)
173 }
174}
175
176#[pyclass]
178#[derive(Clone, Copy, Debug, PartialEq)]
179pub enum HolographicEncoding {
180 Spectral,
182 RandomProjection,
184 LowRankDistributed,
186}
187
188impl From<RustHolographicEncoding> for HolographicEncoding {
189 fn from(enc: RustHolographicEncoding) -> Self {
190 match enc {
191 RustHolographicEncoding::Spectral => HolographicEncoding::Spectral,
192 RustHolographicEncoding::RandomProjection => HolographicEncoding::RandomProjection,
193 RustHolographicEncoding::LowRankDistributed => HolographicEncoding::LowRankDistributed,
194 }
195 }
196}
197
198impl From<HolographicEncoding> for RustHolographicEncoding {
199 fn from(enc: HolographicEncoding) -> Self {
200 match enc {
201 HolographicEncoding::Spectral => RustHolographicEncoding::Spectral,
202 HolographicEncoding::RandomProjection => RustHolographicEncoding::RandomProjection,
203 HolographicEncoding::LowRankDistributed => RustHolographicEncoding::LowRankDistributed,
204 }
205 }
206}
207
208#[pymethods]
209impl HolographicEncoding {
210 fn __repr__(&self) -> String {
211 format!("HolographicEncoding.{:?}", self)
212 }
213}
214
215#[pyclass]
221#[derive(Clone)]
222pub struct HctHeader {
223 #[pyo3(get)]
224 pub algorithm: CompressionAlgorithm,
225 #[pyo3(get)]
226 pub dtype: DType,
227 #[pyo3(get)]
228 pub shape: Vec<u64>,
229 #[pyo3(get)]
230 pub original_size: u64,
231 #[pyo3(get)]
232 pub compressed_size: u64,
233 #[pyo3(get)]
234 pub block_size: u32,
235 #[pyo3(get)]
236 pub num_blocks: u32,
237}
238
239impl From<&RustHctHeader> for HctHeader {
240 fn from(header: &RustHctHeader) -> Self {
241 HctHeader {
242 algorithm: header.algorithm.into(),
243 dtype: header.dtype.into(),
244 shape: header.shape.clone(),
245 original_size: header.original_size,
246 compressed_size: header.compressed_size,
247 block_size: header.block_size,
248 num_blocks: header.num_blocks,
249 }
250 }
251}
252
253#[pymethods]
254impl HctHeader {
255 fn numel(&self) -> u64 {
257 self.shape.iter().product()
258 }
259
260 fn compression_ratio(&self) -> f64 {
262 if self.compressed_size == 0 {
263 0.0
264 } else {
265 self.original_size as f64 / self.compressed_size as f64
266 }
267 }
268
269 fn __repr__(&self) -> String {
270 format!(
271 "HctHeader(dtype={:?}, shape={:?}, ratio={:.2}x)",
272 self.dtype,
273 self.shape,
274 self.compression_ratio()
275 )
276 }
277}
278
279#[pyclass]
287pub struct HctReader {
288 reader: HctReaderV2<BufReader<File>>,
289 path: String,
290}
291
292#[pymethods]
293impl HctReader {
294 #[new]
296 fn new(path: &str) -> PyResult<Self> {
297 let file = File::open(path)
298 .map_err(|e| PyIOError::new_err(format!("Failed to open {}: {}", path, e)))?;
299 let buf_reader = BufReader::new(file);
300 let reader = RustHctReaderV2::new(buf_reader)
301 .map_err(|e| PyIOError::new_err(format!("Failed to read HCT header: {}", e)))?;
302 Ok(HctReader {
303 reader,
304 path: path.to_string(),
305 })
306 }
307
308 fn header(&self) -> HctHeader {
310 HctHeader::from(self.reader.header())
311 }
312
313 fn num_blocks(&self) -> usize {
315 self.reader.num_blocks()
316 }
317
318 fn decompress_all<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f32>>> {
320 let algorithm = self.reader.header().algorithm;
322 let dtype = self.reader.header().dtype;
323
324 let data = match algorithm {
326 RustCompressionAlgorithm::Lz4 => {
327 let decompressor = Lz4Decompressor::new();
328 self.reader
329 .decompress_all_validated(&decompressor)
330 .map_err(|e| PyIOError::new_err(format!("Decompression failed: {}", e)))?
331 }
332 RustCompressionAlgorithm::Zstd => {
333 let decompressor = ZstdDecompressor::new();
334 self.reader
335 .decompress_all_validated(&decompressor)
336 .map_err(|e| PyIOError::new_err(format!("Decompression failed: {}", e)))?
337 }
338 };
339
340 let floats = bytes_to_f32(&data, dtype)?;
342 Ok(floats.into_pyarray_bound(py))
343 }
344
345 fn decompress_block<'py>(
347 &mut self,
348 py: Python<'py>,
349 block_idx: usize,
350 ) -> PyResult<Bound<'py, PyArray1<u8>>> {
351 let algorithm = self.reader.header().algorithm;
352
353 let data = match algorithm {
354 RustCompressionAlgorithm::Lz4 => {
355 let decompressor = Lz4Decompressor::new();
356 self.reader
357 .decompress_block_validated(block_idx, &decompressor)
358 .map_err(|e| PyIOError::new_err(format!("Block decompression failed: {}", e)))?
359 }
360 RustCompressionAlgorithm::Zstd => {
361 let decompressor = ZstdDecompressor::new();
362 self.reader
363 .decompress_block_validated(block_idx, &decompressor)
364 .map_err(|e| PyIOError::new_err(format!("Block decompression failed: {}", e)))?
365 }
366 };
367
368 Ok(data.into_pyarray_bound(py))
369 }
370
371 fn validate_checksums(&mut self) -> PyResult<()> {
373 self.reader
374 .validate_checksums()
375 .map_err(|e| PyValueError::new_err(format!("Checksum validation failed: {}", e)))
376 }
377
378 fn __repr__(&self) -> String {
379 format!("HctReader('{}', blocks={})", self.path, self.num_blocks())
380 }
381}
382
383#[pyclass]
389pub struct HctWriter {
390 writer: Option<HctWriterV2<BufWriter<File>>>,
391 path: String,
392 algorithm: CompressionAlgorithm,
393}
394
395#[pymethods]
396impl HctWriter {
397 #[new]
399 #[pyo3(signature = (path, algorithm, dtype, shape, block_size=None))]
400 fn new(
401 path: &str,
402 algorithm: CompressionAlgorithm,
403 dtype: DType,
404 shape: Vec<u64>,
405 block_size: Option<u32>,
406 ) -> PyResult<Self> {
407 let file = File::create(path)
408 .map_err(|e| PyIOError::new_err(format!("Failed to create {}: {}", path, e)))?;
409 let buf_writer = BufWriter::new(file);
410
411 let mut writer = RustHctWriterV2::new(buf_writer, algorithm.into(), dtype.into(), shape);
412
413 if let Some(bs) = block_size {
414 writer = writer.with_block_size(bs);
415 }
416
417 Ok(HctWriter {
418 writer: Some(writer),
419 path: path.to_string(),
420 algorithm,
421 })
422 }
423
424 fn compress_data(&mut self, data: PyReadonlyArray1<f32>) -> PyResult<()> {
426 let writer = self
427 .writer
428 .as_mut()
429 .ok_or_else(|| PyValueError::new_err("Writer already finalized"))?;
430
431 let slice = data.as_slice()?;
432 let bytes: Vec<u8> = slice.iter().flat_map(|f| f.to_le_bytes()).collect();
433
434 match self.algorithm {
435 CompressionAlgorithm::Lz4 => {
436 let compressor = Lz4Compressor::new();
437 writer
438 .compress_data(&bytes, &compressor)
439 .map_err(|e| PyIOError::new_err(format!("Compression failed: {}", e)))?;
440 }
441 CompressionAlgorithm::Zstd => {
442 let compressor = ZstdCompressor::new(); writer
444 .compress_data(&bytes, &compressor)
445 .map_err(|e| PyIOError::new_err(format!("Compression failed: {}", e)))?;
446 }
447 }
448
449 Ok(())
450 }
451
452 fn finish(&mut self) -> PyResult<()> {
454 let writer = self
455 .writer
456 .take()
457 .ok_or_else(|| PyValueError::new_err("Writer already finalized"))?;
458
459 writer
460 .finish()
461 .map_err(|e| PyIOError::new_err(format!("Failed to finalize: {}", e)))
462 }
463
464 fn __repr__(&self) -> String {
465 format!("HctWriter('{}')", self.path)
466 }
467}
468
469#[pyclass]
484pub struct HoloTensorEncoder {
485 encoder: haagenti::HoloTensorEncoder,
486 encoding: HolographicEncoding,
487 n_fragments: u16,
488}
489
490#[pymethods]
491impl HoloTensorEncoder {
492 #[new]
501 #[pyo3(signature = (encoding, n_fragments=None, seed=None, essential_ratio=None, max_rank=None))]
502 fn new(
503 encoding: HolographicEncoding,
504 n_fragments: Option<u16>,
505 seed: Option<u64>,
506 essential_ratio: Option<f32>,
507 max_rank: Option<usize>,
508 ) -> Self {
509 let n_frags = n_fragments.unwrap_or(8);
510 let mut encoder = haagenti::HoloTensorEncoder::new(encoding.into()).with_fragments(n_frags);
511
512 if let Some(s) = seed {
513 encoder = encoder.with_seed(s);
514 }
515 if let Some(r) = essential_ratio {
516 encoder = encoder.with_essential_ratio(r);
517 }
518 if let Some(r) = max_rank {
519 encoder = encoder.with_max_rank(r);
520 }
521
522 HoloTensorEncoder {
523 encoder,
524 encoding,
525 n_fragments: n_frags,
526 }
527 }
528
529 fn encode_2d(
539 &self,
540 data: PyReadonlyArray1<f32>,
541 rows: usize,
542 cols: usize,
543 ) -> PyResult<(HoloTensorHeaderPy, Vec<HoloFragmentPy>)> {
544 let slice = data.as_slice()?;
545
546 if slice.len() != rows * cols {
547 return Err(PyValueError::new_err(format!(
548 "Data length {} doesn't match {}x{}={}",
549 slice.len(),
550 rows,
551 cols,
552 rows * cols
553 )));
554 }
555
556 let (header, fragments) = self
557 .encoder
558 .encode_2d(slice, rows, cols)
559 .map_err(|e| PyValueError::new_err(format!("Encoding failed: {}", e)))?;
560
561 let header_py = HoloTensorHeaderPy::from(&header);
563 let fragments_py: Vec<HoloFragmentPy> =
564 fragments.into_iter().map(HoloFragmentPy::from).collect();
565
566 Ok((header_py, fragments_py))
567 }
568
569 fn encode_1d(
571 &self,
572 data: PyReadonlyArray1<f32>,
573 ) -> PyResult<(HoloTensorHeaderPy, Vec<HoloFragmentPy>)> {
574 let slice = data.as_slice()?;
575
576 let (header, fragments) = self
577 .encoder
578 .encode_1d(slice)
579 .map_err(|e| PyValueError::new_err(format!("Encoding failed: {}", e)))?;
580
581 let header_py = HoloTensorHeaderPy::from(&header);
582 let fragments_py: Vec<HoloFragmentPy> =
583 fragments.into_iter().map(HoloFragmentPy::from).collect();
584
585 Ok((header_py, fragments_py))
586 }
587
588 #[getter]
590 fn encoding(&self) -> HolographicEncoding {
591 self.encoding
592 }
593
594 #[getter]
596 fn n_fragments(&self) -> u16 {
597 self.n_fragments
598 }
599
600 fn __repr__(&self) -> String {
601 format!(
602 "HoloTensorEncoder(encoding={:?}, n_fragments={})",
603 self.encoding, self.n_fragments
604 )
605 }
606}
607
608#[pyclass]
626pub struct HoloTensorDecoder {
627 decoder: haagenti::HoloTensorDecoder,
628 header: HoloTensorHeaderPy,
629}
630
631#[pymethods]
632impl HoloTensorDecoder {
633 #[new]
635 fn new(header: HoloTensorHeaderPy) -> PyResult<Self> {
636 let rust_header = header.to_rust_header()?;
637 Ok(HoloTensorDecoder {
638 decoder: haagenti::HoloTensorDecoder::new(rust_header),
639 header,
640 })
641 }
642
643 fn add_fragment(&mut self, fragment: &HoloFragmentPy) -> PyResult<f32> {
647 let rust_fragment = fragment.to_rust_fragment();
648 self.decoder
649 .add_fragment(rust_fragment)
650 .map_err(|e| PyValueError::new_err(format!("Failed to add fragment: {}", e)))
651 }
652
653 #[getter]
658 fn quality(&self) -> f32 {
659 self.decoder.quality()
660 }
661
662 #[getter]
664 fn fragments_loaded(&self) -> u16 {
665 self.decoder.fragments_loaded()
666 }
667
668 #[getter]
670 fn total_fragments(&self) -> u16 {
671 self.header.total_fragments
672 }
673
674 fn can_reconstruct(&self) -> bool {
676 self.decoder.can_reconstruct()
677 }
678
679 fn reconstruct<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f32>>> {
684 let data = self
685 .decoder
686 .reconstruct()
687 .map_err(|e| PyValueError::new_err(format!("Reconstruction failed: {}", e)))?;
688 Ok(data.into_pyarray_bound(py))
689 }
690
691 #[getter]
693 fn header(&self) -> HoloTensorHeaderPy {
694 self.header.clone()
695 }
696
697 fn __repr__(&self) -> String {
698 format!(
699 "HoloTensorDecoder(quality={:.1}%, fragments={}/{})",
700 self.quality() * 100.0,
701 self.fragments_loaded(),
702 self.total_fragments()
703 )
704 }
705}
706
707#[pyclass]
713#[derive(Clone)]
714pub struct HoloTensorHeaderPy {
715 #[pyo3(get)]
716 pub encoding: HolographicEncoding,
717 #[pyo3(get)]
718 pub total_fragments: u16,
719 #[pyo3(get)]
720 pub min_fragments: u16,
721 #[pyo3(get)]
722 pub shape: Vec<u64>,
723 #[pyo3(get)]
724 pub original_size: u64,
725 #[pyo3(get)]
726 pub seed: u64,
727}
728
729impl From<&haagenti::HoloTensorHeader> for HoloTensorHeaderPy {
730 fn from(h: &haagenti::HoloTensorHeader) -> Self {
731 HoloTensorHeaderPy {
732 encoding: h.encoding.into(),
733 total_fragments: h.total_fragments,
734 min_fragments: h.min_fragments,
735 shape: h.shape.clone(),
736 original_size: h.original_size,
737 seed: h.seed,
738 }
739 }
740}
741
742impl HoloTensorHeaderPy {
743 fn to_rust_header(&self) -> PyResult<haagenti::HoloTensorHeader> {
744 Ok(haagenti::HoloTensorHeader {
745 encoding: self.encoding.into(),
746 compression: RustCompressionAlgorithm::Lz4,
747 flags: 0,
748 total_fragments: self.total_fragments,
749 min_fragments: self.min_fragments,
750 original_size: self.original_size,
751 seed: self.seed,
752 dtype: RustDType::F32,
753 shape: self.shape.clone(),
754 quality_curve: haagenti::QualityCurve::default(),
755 quantization: None,
756 })
757 }
758}
759
760#[pymethods]
761impl HoloTensorHeaderPy {
762 fn __repr__(&self) -> String {
763 format!(
764 "HoloTensorHeader(encoding={:?}, fragments={}, shape={:?})",
765 self.encoding, self.total_fragments, self.shape
766 )
767 }
768}
769
770#[pyclass]
779#[derive(Clone)]
780pub struct HoloFragmentPy {
781 #[pyo3(get)]
782 pub index: u16,
783 #[pyo3(get)]
784 pub flags: u16,
785 #[pyo3(get)]
786 pub checksum: u64,
787 data: Vec<u8>,
788}
789
790impl From<haagenti::HoloFragment> for HoloFragmentPy {
791 fn from(f: haagenti::HoloFragment) -> Self {
792 HoloFragmentPy {
793 index: f.index,
794 flags: f.flags,
795 checksum: f.checksum,
796 data: f.data,
797 }
798 }
799}
800
801impl HoloFragmentPy {
802 fn to_rust_fragment(&self) -> haagenti::HoloFragment {
803 haagenti::HoloFragment {
804 index: self.index,
805 flags: self.flags,
806 checksum: self.checksum,
807 data: self.data.clone(),
808 }
809 }
810}
811
812#[pymethods]
813impl HoloFragmentPy {
814 fn data<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<u8>> {
816 self.data.clone().into_pyarray_bound(py)
817 }
818
819 #[getter]
821 fn size(&self) -> usize {
822 self.data.len()
823 }
824
825 fn __repr__(&self) -> String {
826 format!(
827 "HoloFragment(index={}, size={})",
828 self.index,
829 self.data.len()
830 )
831 }
832}
833
834#[pyfunction]
840#[pyo3(signature = (input_path, output_path, algorithm=CompressionAlgorithm::Lz4))]
841fn convert_safetensors_to_hct(
842 input_path: &str,
843 output_path: &str,
844 algorithm: CompressionAlgorithm,
845) -> PyResult<(u64, u64, f64)> {
846 let data = std::fs::read(input_path)
848 .map_err(|e| PyIOError::new_err(format!("Failed to read {}: {}", input_path, e)))?;
849
850 let original_size = data.len() as u64;
853
854 let file = File::create(output_path)
856 .map_err(|e| PyIOError::new_err(format!("Failed to create {}: {}", output_path, e)))?;
857 let buf_writer = BufWriter::new(file);
858
859 let mut writer = RustHctWriterV2::new(
860 buf_writer,
861 algorithm.into(),
862 RustDType::F32, vec![data.len() as u64], );
865
866 match algorithm {
867 CompressionAlgorithm::Lz4 => {
868 let compressor = Lz4Compressor::new();
869 writer
870 .compress_data(&data, &compressor)
871 .map_err(|e| PyIOError::new_err(format!("Compression failed: {}", e)))?;
872 }
873 CompressionAlgorithm::Zstd => {
874 let compressor = ZstdCompressor::new();
875 writer
876 .compress_data(&data, &compressor)
877 .map_err(|e| PyIOError::new_err(format!("Compression failed: {}", e)))?;
878 }
879 }
880
881 writer
882 .finish()
883 .map_err(|e| PyIOError::new_err(format!("Failed to finalize: {}", e)))?;
884
885 let compressed_size = std::fs::metadata(output_path)
887 .map_err(|e| PyIOError::new_err(format!("Failed to stat {}: {}", output_path, e)))?
888 .len();
889
890 let ratio = original_size as f64 / compressed_size as f64;
891
892 Ok((original_size, compressed_size, ratio))
893}
894
895#[pyfunction]
897fn version() -> String {
898 env!("CARGO_PKG_VERSION").to_string()
899}
900
901#[pyfunction]
916#[pyo3(signature = (data, algorithm="zstd", level="default", dictionary=None))]
917fn compress(
918 data: &[u8],
919 algorithm: &str,
920 level: &str,
921 dictionary: Option<&ZstdDict>,
922) -> PyResult<Vec<u8>> {
923 let _ = dictionary; let compression_level = match level {
926 "fast" => haagenti_core::CompressionLevel::Fast,
927 "default" => haagenti_core::CompressionLevel::Default,
928 "best" => haagenti_core::CompressionLevel::Best,
929 _ => return Err(PyValueError::new_err(format!("Invalid level: {}", level))),
930 };
931
932 match algorithm.to_lowercase().as_str() {
933 "zstd" => {
934 let compressor = ZstdCompressor::with_level(compression_level);
935 compressor
936 .compress(data)
937 .map_err(|e| PyValueError::new_err(format!("Zstd compression failed: {}", e)))
938 }
939 "lz4" => {
940 let compressor = Lz4Compressor::new();
941 compressor
942 .compress(data)
943 .map_err(|e| PyValueError::new_err(format!("LZ4 compression failed: {}", e)))
944 }
945 _ => Err(PyValueError::new_err(format!(
946 "Invalid algorithm: {}. Use 'zstd' or 'lz4'",
947 algorithm
948 ))),
949 }
950}
951
952#[pyfunction]
961#[pyo3(signature = (data, algorithm="zstd"))]
962fn decompress(data: &[u8], algorithm: &str) -> PyResult<Vec<u8>> {
963 match algorithm.to_lowercase().as_str() {
964 "zstd" => {
965 let decompressor = ZstdDecompressor::new();
966 decompressor.decompress(data).map_err(|e| {
967 DecompressionError::new_err(format!("Zstd decompression failed: {}", e))
968 })
969 }
970 "lz4" => {
971 let decompressor = Lz4Decompressor::new();
972 decompressor.decompress(data).map_err(|e| {
973 DecompressionError::new_err(format!("LZ4 decompression failed: {}", e))
974 })
975 }
976 _ => Err(PyValueError::new_err(format!(
977 "Invalid algorithm: {}. Use 'zstd' or 'lz4'",
978 algorithm
979 ))),
980 }
981}
982
983#[pyclass]
989#[derive(Clone)]
990pub struct ZstdDict {
991 id: u32,
992 data: Vec<u8>,
993}
994
995#[pymethods]
996impl ZstdDict {
997 #[staticmethod]
1006 #[pyo3(signature = (samples, max_size=8192))]
1007 fn train(samples: Vec<Vec<u8>>, max_size: usize) -> PyResult<Self> {
1008 use haagenti_zstd::ZstdDictionary;
1009
1010 if samples.len() < 5 {
1011 return Err(PyValueError::new_err(
1012 "Need at least 5 samples for dictionary training",
1013 ));
1014 }
1015
1016 let sample_refs: Vec<&[u8]> = samples.iter().map(|s| s.as_slice()).collect();
1017 let dict = ZstdDictionary::train(&sample_refs, max_size)
1018 .map_err(|e| PyValueError::new_err(format!("Dictionary training failed: {}", e)))?;
1019
1020 Ok(ZstdDict {
1021 id: dict.id(),
1022 data: dict.serialize(),
1023 })
1024 }
1025
1026 #[getter]
1028 fn id(&self) -> u32 {
1029 self.id
1030 }
1031
1032 fn as_bytes(&self) -> Vec<u8> {
1034 self.data.clone()
1035 }
1036
1037 fn __repr__(&self) -> String {
1038 format!("ZstdDict(id={}, size={})", self.id, self.data.len())
1039 }
1040}
1041
1042#[pyclass]
1048pub struct StreamingEncoder {
1049 algorithm: String,
1050 buffer: Vec<u8>,
1051}
1052
1053#[pymethods]
1054impl StreamingEncoder {
1055 #[new]
1057 fn new(algorithm: &str) -> PyResult<Self> {
1058 match algorithm.to_lowercase().as_str() {
1059 "zstd" | "lz4" => Ok(StreamingEncoder {
1060 algorithm: algorithm.to_lowercase(),
1061 buffer: Vec::new(),
1062 }),
1063 _ => Err(PyValueError::new_err(format!(
1064 "Invalid algorithm: {}",
1065 algorithm
1066 ))),
1067 }
1068 }
1069
1070 fn write(&mut self, data: &[u8]) -> PyResult<()> {
1072 self.buffer.extend_from_slice(data);
1073 Ok(())
1074 }
1075
1076 fn finish(&mut self) -> PyResult<Vec<u8>> {
1078 let result = match self.algorithm.as_str() {
1079 "zstd" => {
1080 let compressor = ZstdCompressor::new();
1081 compressor
1082 .compress(&self.buffer)
1083 .map_err(|e| PyValueError::new_err(format!("Compression failed: {}", e)))?
1084 }
1085 "lz4" => {
1086 let compressor = Lz4Compressor::new();
1087 compressor
1088 .compress(&self.buffer)
1089 .map_err(|e| PyValueError::new_err(format!("Compression failed: {}", e)))?
1090 }
1091 _ => return Err(PyValueError::new_err("Invalid algorithm")),
1092 };
1093 self.buffer.clear();
1094 Ok(result)
1095 }
1096
1097 fn __enter__(slf: PyRef<Self>) -> PyRef<Self> {
1098 slf
1099 }
1100
1101 fn __exit__(
1102 &mut self,
1103 _exc_type: Option<PyObject>,
1104 _exc_val: Option<PyObject>,
1105 _exc_tb: Option<PyObject>,
1106 ) -> bool {
1107 false
1108 }
1109}
1110
1111#[pyclass]
1113pub struct StreamingDecoder {
1114 algorithm: String,
1115 buffer: Vec<u8>,
1116}
1117
1118#[pymethods]
1119impl StreamingDecoder {
1120 #[new]
1122 fn new(algorithm: &str) -> PyResult<Self> {
1123 match algorithm.to_lowercase().as_str() {
1124 "zstd" | "lz4" => Ok(StreamingDecoder {
1125 algorithm: algorithm.to_lowercase(),
1126 buffer: Vec::new(),
1127 }),
1128 _ => Err(PyValueError::new_err(format!(
1129 "Invalid algorithm: {}",
1130 algorithm
1131 ))),
1132 }
1133 }
1134
1135 fn write(&mut self, data: &[u8]) -> PyResult<Option<Vec<u8>>> {
1139 self.buffer.extend_from_slice(data);
1140 Ok(None) }
1142
1143 fn finish(&mut self) -> PyResult<Vec<u8>> {
1145 let result = match self.algorithm.as_str() {
1146 "zstd" => {
1147 let decompressor = ZstdDecompressor::new();
1148 decompressor.decompress(&self.buffer).map_err(|e| {
1149 DecompressionError::new_err(format!("Decompression failed: {}", e))
1150 })?
1151 }
1152 "lz4" => {
1153 let decompressor = Lz4Decompressor::new();
1154 decompressor.decompress(&self.buffer).map_err(|e| {
1155 DecompressionError::new_err(format!("Decompression failed: {}", e))
1156 })?
1157 }
1158 _ => return Err(PyValueError::new_err("Invalid algorithm")),
1159 };
1160 self.buffer.clear();
1161 Ok(result)
1162 }
1163
1164 fn __enter__(slf: PyRef<Self>) -> PyRef<Self> {
1165 slf
1166 }
1167
1168 fn __exit__(
1169 &mut self,
1170 _exc_type: Option<PyObject>,
1171 _exc_val: Option<PyObject>,
1172 _exc_tb: Option<PyObject>,
1173 ) -> bool {
1174 false
1175 }
1176}
1177
// Python exception type raised when a codec rejects compressed input (used by `decompress`
// and `StreamingDecoder.finish`). Subclasses Python's base `Exception` and is exported from
// the module as `DecompressionError`.
pyo3::create_exception!(haagenti, DecompressionError, pyo3::exceptions::PyException);
1183
1184fn bytes_to_f32(data: &[u8], dtype: RustDType) -> PyResult<Vec<f32>> {
1189 match dtype {
1190 RustDType::F32 => {
1191 if data.len() % 4 != 0 {
1192 return Err(PyValueError::new_err("Invalid F32 data length"));
1193 }
1194 Ok(data
1195 .chunks_exact(4)
1196 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
1197 .collect())
1198 }
1199 RustDType::F16 => {
1200 if data.len() % 2 != 0 {
1201 return Err(PyValueError::new_err("Invalid F16 data length"));
1202 }
1203 Ok(data
1204 .chunks_exact(2)
1205 .map(|b| {
1206 let bits = u16::from_le_bytes([b[0], b[1]]);
1207 half::f16::from_bits(bits).to_f32()
1208 })
1209 .collect())
1210 }
1211 RustDType::BF16 => {
1212 if data.len() % 2 != 0 {
1213 return Err(PyValueError::new_err("Invalid BF16 data length"));
1214 }
1215 Ok(data
1216 .chunks_exact(2)
1217 .map(|b| {
1218 let bits = u16::from_le_bytes([b[0], b[1]]);
1219 half::bf16::from_bits(bits).to_f32()
1220 })
1221 .collect())
1222 }
1223 RustDType::I8 => Ok(data.iter().map(|&b| b as i8 as f32).collect()),
1224 RustDType::I4 => {
1225 Ok(data
1227 .iter()
1228 .flat_map(|&b| {
1229 let lo = (b & 0x0F) as i8;
1230 let hi = ((b >> 4) & 0x0F) as i8;
1231 let lo = if lo & 0x08 != 0 {
1233 lo | 0xF0u8 as i8
1234 } else {
1235 lo
1236 };
1237 let hi = if hi & 0x08 != 0 {
1238 hi | 0xF0u8 as i8
1239 } else {
1240 hi
1241 };
1242 vec![lo as f32, hi as f32]
1243 })
1244 .collect())
1245 }
1246 }
1247}
1248
1249#[pymodule]
1255fn _haagenti_python(m: &Bound<'_, PyModule>) -> PyResult<()> {
1256 m.add_class::<CompressionAlgorithm>()?;
1258 m.add_class::<DType>()?;
1259 m.add_class::<QuantizationScheme>()?;
1260 m.add_class::<HolographicEncoding>()?;
1261
1262 m.add_class::<HctHeader>()?;
1264 m.add_class::<HctReader>()?;
1265 m.add_class::<HctWriter>()?;
1266
1267 m.add_class::<HoloTensorEncoder>()?;
1269 m.add_class::<HoloTensorDecoder>()?;
1270 m.add_class::<HoloTensorHeaderPy>()?;
1271 m.add_class::<HoloFragmentPy>()?;
1272
1273 m.add_class::<ZstdDict>()?;
1275 m.add_class::<StreamingEncoder>()?;
1276 m.add_class::<StreamingDecoder>()?;
1277
1278 m.add_function(wrap_pyfunction!(convert_safetensors_to_hct, m)?)?;
1280 m.add_function(wrap_pyfunction!(version, m)?)?;
1281 m.add_function(wrap_pyfunction!(compress, m)?)?;
1282 m.add_function(wrap_pyfunction!(decompress, m)?)?;
1283
1284 m.add(
1286 "DecompressionError",
1287 m.py().get_type_bound::<DecompressionError>(),
1288 )?;
1289
1290 let streaming = PyModule::new_bound(m.py(), "streaming")?;
1292 streaming.add_class::<StreamingEncoder>()?;
1293 streaming.add_class::<StreamingDecoder>()?;
1294 m.add_submodule(&streaming)?;
1295
1296 Ok(())
1297}