1use haagenti_core::{Error, Result};
25use std::collections::HashMap;
26
/// Magic number identifying a serialized Zstd dictionary (RFC 8878, 0xEC30A437).
pub const DICT_MAGIC: u32 = 0xEC30A437;

/// Maximum accepted dictionary content size (128 KiB).
pub const MAX_DICT_SIZE: usize = 128 * 1024;

/// Minimum accepted dictionary content size in bytes.
pub const MIN_DICT_SIZE: usize = 8;

/// Minimum number of samples required by `ZstdDictionary::train`.
pub const MIN_SAMPLES: usize = 5;
38
/// A Zstd dictionary: content shared out-of-band between compressor and
/// decompressor, plus a hash index for fast match lookup during compression.
#[derive(Debug, Clone)]
pub struct ZstdDictionary {
    // 32-bit dictionary ID (derived from the content hash in `from_content`,
    // or read from the serialized header in `parse`).
    id: u32,
    // Raw dictionary bytes, used as a virtual prefix for match finding.
    content: Vec<u8>,
    // Pre-built Huffman literals table — reserved, never populated here.
    #[allow(dead_code)]
    huffman_table: Option<Vec<u8>>,
    // Pre-built FSE offset-code table — reserved, never populated here.
    #[allow(dead_code)]
    fse_offset_table: Option<Vec<u8>>,
    // Pre-built FSE match-length table — reserved, never populated here.
    #[allow(dead_code)]
    fse_ml_table: Option<Vec<u8>>,
    // Pre-built FSE literal-length table — reserved, never populated here.
    #[allow(dead_code)]
    fse_ll_table: Option<Vec<u8>>,
    // Maps hash4() of every 4-byte window of `content` to the positions
    // where that window occurs; built by `build_hash_table`.
    hash_table: HashMap<u32, Vec<usize>>,
}
61
62impl ZstdDictionary {
63 pub fn from_content(content: Vec<u8>) -> Result<Self> {
65 if content.len() < MIN_DICT_SIZE {
66 return Err(Error::corrupted("Dictionary too small"));
67 }
68 if content.len() > MAX_DICT_SIZE {
69 return Err(Error::corrupted("Dictionary too large"));
70 }
71
72 let id = Self::compute_id(&content);
74
75 let hash_table = Self::build_hash_table(&content);
77
78 Ok(Self {
79 id,
80 content,
81 huffman_table: None,
82 fse_offset_table: None,
83 fse_ml_table: None,
84 fse_ll_table: None,
85 hash_table,
86 })
87 }
88
89 pub fn parse(data: &[u8]) -> Result<Self> {
91 if data.len() < 8 {
92 return Err(Error::corrupted("Dictionary data too short"));
93 }
94
95 let magic = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
97 if magic != DICT_MAGIC {
98 return Err(Error::corrupted("Invalid dictionary magic"));
99 }
100
101 let id = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
103
104 let content = data[8..].to_vec();
107
108 let hash_table = Self::build_hash_table(&content);
109
110 Ok(Self {
111 id,
112 content,
113 huffman_table: None,
114 fse_offset_table: None,
115 fse_ml_table: None,
116 fse_ll_table: None,
117 hash_table,
118 })
119 }
120
121 pub fn train(samples: &[&[u8]], dict_size: usize) -> Result<Self> {
128 if samples.len() < MIN_SAMPLES {
129 return Err(Error::corrupted(format!(
130 "Need at least {} samples for training",
131 MIN_SAMPLES
132 )));
133 }
134
135 let dict_size = dict_size.min(MAX_DICT_SIZE);
136
137 let mut all_data = Vec::new();
139 let mut sample_offsets = Vec::new();
140 for sample in samples {
141 sample_offsets.push(all_data.len());
142 all_data.extend_from_slice(sample);
143 }
144
145 let patterns = Self::find_frequent_patterns(&all_data, samples.len());
147
148 let mut dict_content = Vec::with_capacity(dict_size);
150 for (pattern, _score) in patterns {
151 if dict_content.len() + pattern.len() > dict_size {
152 break;
153 }
154 dict_content.extend_from_slice(&pattern);
155 }
156
157 if dict_content.len() < dict_size {
159 for sample in samples {
160 let remaining = dict_size - dict_content.len();
161 if remaining == 0 {
162 break;
163 }
164 let to_add = sample.len().min(remaining);
165 dict_content.extend_from_slice(&sample[..to_add]);
166 }
167 }
168
169 Self::from_content(dict_content)
170 }
171
172 fn find_frequent_patterns(data: &[u8], num_samples: usize) -> Vec<(Vec<u8>, u64)> {
174 let mut pattern_counts: HashMap<Vec<u8>, u64> = HashMap::new();
175
176 for pattern_len in 4..=32 {
178 if data.len() < pattern_len {
179 break;
180 }
181 for i in 0..=(data.len() - pattern_len) {
182 let pattern = &data[i..i + pattern_len];
183 *pattern_counts.entry(pattern.to_vec()).or_insert(0) += 1;
184 }
185 }
186
187 let mut scored: Vec<_> = pattern_counts
189 .into_iter()
190 .filter(|(_, count)| *count > num_samples as u64) .map(|(pattern, count)| {
192 let score = count * (pattern.len() as u64).pow(2);
193 (pattern, score)
194 })
195 .collect();
196
197 scored.sort_by(|a, b| b.1.cmp(&a.1));
199
200 let mut selected: Vec<(Vec<u8>, u64)> = Vec::new();
202 #[allow(unused_variables)]
203 let used_ranges: Vec<(usize, usize)> = Vec::new();
204
205 'outer: for (pattern, score) in scored {
206 for (existing, _) in &selected {
209 if Self::patterns_overlap(&pattern, existing) {
210 continue 'outer;
211 }
212 }
213 selected.push((pattern, score));
214
215 if selected.len() >= 1000 {
216 break;
217 }
218 }
219
220 selected
221 }
222
223 fn patterns_overlap(a: &[u8], b: &[u8]) -> bool {
225 let min_len = a.len().min(b.len());
226 if min_len < 4 {
227 return a == b;
228 }
229
230 if a.len() >= b.len() {
232 for window in a.windows(b.len()) {
233 if window == b {
234 return true;
235 }
236 }
237 } else {
238 for window in b.windows(a.len()) {
239 if window == a {
240 return true;
241 }
242 }
243 }
244
245 false
246 }
247
248 fn build_hash_table(content: &[u8]) -> HashMap<u32, Vec<usize>> {
250 let mut table: HashMap<u32, Vec<usize>> = HashMap::new();
251
252 if content.len() < 4 {
253 return table;
254 }
255
256 for i in 0..=(content.len() - 4) {
257 let hash = Self::hash4(&content[i..i + 4]);
258 table.entry(hash).or_default().push(i);
259 }
260
261 table
262 }
263
264 fn hash4(data: &[u8]) -> u32 {
266 debug_assert!(data.len() >= 4);
267 let v = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
268 v.wrapping_mul(0x9E3779B9)
270 }
271
272 fn compute_id(content: &[u8]) -> u32 {
274 let hash = crate::frame::xxhash64(content, 0);
276 (hash & 0xFFFFFFFF) as u32
277 }
278
279 pub fn id(&self) -> u32 {
281 self.id
282 }
283
284 pub fn content(&self) -> &[u8] {
286 &self.content
287 }
288
289 pub fn size(&self) -> usize {
291 self.content.len()
292 }
293
294 pub fn serialize(&self) -> Vec<u8> {
296 let mut result = Vec::with_capacity(8 + self.content.len());
297
298 result.extend_from_slice(&DICT_MAGIC.to_le_bytes());
300
301 result.extend_from_slice(&self.id.to_le_bytes());
303
304 result.extend_from_slice(&self.content);
306
307 result
308 }
309
310 pub fn find_match(&self, input: &[u8], pos: usize) -> Option<DictMatch> {
312 if pos + 4 > input.len() {
313 return None;
314 }
315
316 let hash = Self::hash4(&input[pos..pos + 4]);
317 let candidates = self.hash_table.get(&hash)?;
318
319 let mut best_match: Option<DictMatch> = None;
320 let max_len = input.len() - pos;
321
322 for &dict_pos in candidates {
323 let mut match_len = 0;
325 while match_len < max_len
326 && dict_pos + match_len < self.content.len()
327 && input[pos + match_len] == self.content[dict_pos + match_len]
328 {
329 match_len += 1;
330 }
331
332 if match_len >= 4 {
334 let offset = self.content.len() - dict_pos;
335 if best_match
336 .as_ref()
337 .map(|m| match_len > m.length)
338 .unwrap_or(true)
339 {
340 best_match = Some(DictMatch {
341 offset,
342 length: match_len,
343 dict_position: dict_pos,
344 });
345 }
346 }
347 }
348
349 best_match
350 }
351
352 pub fn get_byte(&self, pos: usize) -> Option<u8> {
354 self.content.get(pos).copied()
355 }
356}
357
/// A match located in the dictionary for some input position.
#[derive(Debug, Clone, Copy)]
pub struct DictMatch {
    /// Offset measured backwards from the end of the dictionary content.
    pub offset: usize,
    /// Match length in bytes (at least 4 when produced by `find_match`).
    pub length: usize,
    /// Absolute start position of the match within the dictionary content.
    pub dict_position: usize,
}
368
/// Compresses data with a shared [`ZstdDictionary`], embedding the
/// dictionary ID in each produced frame.
#[derive(Debug)]
pub struct ZstdDictCompressor {
    // Dictionary whose ID is stamped into frame headers.
    dictionary: ZstdDictionary,
    // Compression level forwarded to the compression context.
    level: haagenti_core::CompressionLevel,
}
375
376impl ZstdDictCompressor {
377 pub fn new(dictionary: ZstdDictionary) -> Self {
379 Self {
380 dictionary,
381 level: haagenti_core::CompressionLevel::Default,
382 }
383 }
384
385 pub fn with_level(dictionary: ZstdDictionary, level: haagenti_core::CompressionLevel) -> Self {
387 Self { dictionary, level }
388 }
389
390 pub fn dictionary(&self) -> &ZstdDictionary {
392 &self.dictionary
393 }
394
395 pub fn compress(&self, input: &[u8]) -> Result<Vec<u8>> {
397 let mut ctx = crate::compress::CompressContext::new(self.level);
400 ctx.set_dictionary_id(self.dictionary.id());
401 ctx.compress(input)
402 }
403}
404
/// Decompresses frames produced with a shared [`ZstdDictionary`],
/// verifying the frame's embedded dictionary ID first.
#[derive(Debug)]
pub struct ZstdDictDecompressor {
    // Dictionary passed to the frame decoder; its ID must match the frame's.
    dictionary: ZstdDictionary,
}
410
411impl ZstdDictDecompressor {
412 pub fn new(dictionary: ZstdDictionary) -> Self {
414 Self { dictionary }
415 }
416
417 pub fn dictionary(&self) -> &ZstdDictionary {
419 &self.dictionary
420 }
421
422 pub fn decompress(&self, input: &[u8]) -> Result<Vec<u8>> {
424 if input.len() < 8 {
426 return Err(Error::corrupted("Input too short"));
427 }
428
429 let magic = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
431 if magic != crate::ZSTD_MAGIC {
432 return Err(Error::corrupted("Invalid Zstd magic"));
433 }
434
435 let descriptor = input[4];
437 let has_dict_id = (descriptor & 0x03) != 0;
438
439 if has_dict_id {
440 let dict_id_size = match descriptor & 0x03 {
442 1 => 1,
443 2 => 2,
444 3 => 4,
445 _ => 0,
446 };
447
448 if dict_id_size > 0 {
449 let offset = if (descriptor & 0x20) == 0 { 6 } else { 5 };
450 let frame_dict_id = match dict_id_size {
451 1 => input[offset] as u32,
452 2 => u16::from_le_bytes([input[offset], input[offset + 1]]) as u32,
453 4 => u32::from_le_bytes([
454 input[offset],
455 input[offset + 1],
456 input[offset + 2],
457 input[offset + 3],
458 ]),
459 _ => 0,
460 };
461
462 if frame_dict_id != self.dictionary.id() {
463 return Err(Error::corrupted(format!(
464 "Dictionary ID mismatch: expected {}, got {}",
465 self.dictionary.id(),
466 frame_dict_id
467 )));
468 }
469 }
470 }
471
472 crate::decompress::decompress_frame_with_dict(input, Some(&self.dictionary))
474 }
475}
476
// Unit tests covering dictionary construction, serialization round-trips,
// training, match finding, and dictionary-aware compression/decompression.
#[cfg(test)]
mod tests {
    use super::*;

    // from_content: size is preserved and the derived ID is non-zero.
    #[test]
    fn test_dictionary_creation() {
        let content = b"Hello World! This is test dictionary content.";
        let dict = ZstdDictionary::from_content(content.to_vec()).unwrap();

        assert_eq!(dict.size(), content.len());
        assert!(dict.id() != 0);
    }

    // serialize() followed by parse() round-trips both ID and content.
    #[test]
    fn test_dictionary_serialization() {
        let content = b"Test dictionary content for serialization.";
        let dict = ZstdDictionary::from_content(content.to_vec()).unwrap();

        let serialized = dict.serialize();
        let parsed = ZstdDictionary::parse(&serialized).unwrap();

        assert_eq!(dict.id(), parsed.id());
        assert_eq!(dict.content(), parsed.content());
    }

    // find_match locates a run (>= 4 bytes) that exists in the dictionary.
    #[test]
    fn test_dictionary_match_finding() {
        let content = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let dict = ZstdDictionary::from_content(content.to_vec()).unwrap();

        let input = b"xxDEFGHIJKxx";
        let m = dict.find_match(input, 2);
        assert!(m.is_some());
        let m = m.unwrap();
        assert!(m.length >= 4);
    }

    // train() builds a bounded-size dictionary containing common sample words.
    #[test]
    fn test_dictionary_training() {
        let samples: Vec<&[u8]> = vec![
            b"The quick brown fox jumps",
            b"The quick brown dog runs",
            b"The quick red fox leaps",
            b"A quick brown fox jumps",
            b"The quick brown cat sleeps",
        ];

        let dict = ZstdDictionary::train(&samples, 1024).unwrap();
        assert!(dict.size() > 0);
        assert!(dict.size() <= 1024);

        let content = String::from_utf8_lossy(dict.content());
        assert!(content.contains("quick") || content.contains("brown") || content.contains("The"));
    }

    // Content below MIN_DICT_SIZE is rejected.
    #[test]
    fn test_dictionary_too_small() {
        let result = ZstdDictionary::from_content(vec![1, 2, 3]);
        assert!(result.is_err());
    }

    // Content above MAX_DICT_SIZE is rejected.
    #[test]
    fn test_dictionary_too_large() {
        let content = vec![0u8; MAX_DICT_SIZE + 1];
        let result = ZstdDictionary::from_content(content);
        assert!(result.is_err());
    }

    // Training on tensor-name-like samples yields a valid dictionary that
    // captures the shared naming patterns.
    #[test]
    fn test_dict_training_from_model_samples() {
        let samples: Vec<&[u8]> = vec![
            b"model.layers.0.weight",
            b"model.layers.1.weight",
            b"model.layers.2.weight",
            b"model.layers.3.weight",
            b"model.layers.4.weight",
            b"model.attention.q_proj",
            b"model.attention.k_proj",
            b"model.attention.v_proj",
        ];

        let dict = ZstdDictionary::train(&samples, 8 * 1024).unwrap();

        assert!(dict.id() != 0, "Dictionary should have non-zero ID");
        assert!(
            dict.size() >= MIN_DICT_SIZE,
            "Dictionary should meet minimum size"
        );
        assert!(
            dict.size() <= 8 * 1024,
            "Dictionary should not exceed max size"
        );

        let content = String::from_utf8_lossy(dict.content());
        assert!(
            content.contains("model") || content.contains("layers") || content.contains("weight"),
            "Dictionary should contain common patterns from samples"
        );
    }

    // Fewer than MIN_SAMPLES samples must fail training.
    #[test]
    fn test_dict_training_insufficient_samples() {
        let samples: Vec<&[u8]> = vec![b"single sample", b"another sample"];

        let result = ZstdDictionary::train(&samples, 4096);
        assert!(
            result.is_err(),
            "Training should fail with fewer than {} samples",
            MIN_SAMPLES
        );
    }

    // Dictionary compress -> decompress round-trips the original bytes.
    #[test]
    fn test_dict_compression_roundtrip() {
        let samples: Vec<&[u8]> = vec![
            b"model.layers.0.mlp.gate_proj.weight",
            b"model.layers.1.mlp.gate_proj.weight",
            b"model.layers.2.mlp.gate_proj.weight",
            b"model.layers.3.mlp.gate_proj.weight",
            b"model.layers.4.mlp.gate_proj.weight",
        ];

        let dict = ZstdDictionary::train(&samples, 4096).unwrap();
        let compressor = ZstdDictCompressor::new(dict.clone());
        let decompressor = ZstdDictDecompressor::new(dict);

        let original = b"model.layers.42.mlp.gate_proj.weight tensor data follows";
        let compressed = compressor.compress(original).unwrap();
        let decompressed = decompressor.decompress(&compressed).unwrap();

        assert_eq!(original.as_slice(), decompressed.as_slice());
    }

    // Both dictionary and plain compression produce non-empty output for
    // data resembling the training samples. (No strict ratio assertion.)
    #[test]
    fn test_dict_compression_improves_ratio() {
        let samples: Vec<&[u8]> = vec![
            b"transformer.encoder.layer.0.attention.self.query.weight",
            b"transformer.encoder.layer.1.attention.self.query.weight",
            b"transformer.encoder.layer.2.attention.self.query.weight",
            b"transformer.encoder.layer.3.attention.self.query.weight",
            b"transformer.encoder.layer.4.attention.self.query.weight",
        ];

        let dict = ZstdDictionary::train(&samples, 4096).unwrap();
        let dict_compressor = ZstdDictCompressor::new(dict);

        let test_data =
            b"transformer.encoder.layer.15.attention.self.query.weight tensor data here";

        let with_dict = dict_compressor.compress(test_data).unwrap();
        let without_dict =
            crate::compress::CompressContext::new(haagenti_core::CompressionLevel::Default)
                .compress(test_data)
                .unwrap();

        assert!(
            with_dict.len() > 0 && without_dict.len() > 0,
            "Both compressions should produce output"
        );
    }

    // The produced frame carries a valid magic and, when the dictionary-ID
    // flag is set in the frame header descriptor, a non-zero dictionary ID.
    #[test]
    fn test_dict_id_embedded_in_frame() {
        let samples: Vec<&[u8]> = vec![
            b"pattern.one.test.data",
            b"pattern.two.test.data",
            b"pattern.three.test.data",
            b"pattern.four.test.data",
            b"pattern.five.test.data",
        ];
        let dict = ZstdDictionary::train(&samples, 2048).unwrap();
        let dict_id = dict.id();

        let compressor = ZstdDictCompressor::new(dict);

        let compressed = compressor
            .compress(b"pattern.test.data with more content")
            .unwrap();

        assert!(
            compressed.len() >= 8,
            "Compressed data should have frame header"
        );

        let magic =
            u32::from_le_bytes([compressed[0], compressed[1], compressed[2], compressed[3]]);
        assert_eq!(magic, crate::ZSTD_MAGIC, "Should have valid Zstd magic");

        let descriptor = compressed[4];
        let dict_id_flag = descriptor & 0x03;

        if dict_id_flag != 0 {
            assert!(
                dict_id != 0,
                "Dictionary ID should be non-zero when embedded"
            );
        }
    }

    // The hash index finds a pattern buried in a larger dictionary.
    #[test]
    fn test_dict_hash_table_efficiency() {
        let mut content = Vec::new();
        for i in 0..100 {
            content.extend_from_slice(format!("pattern_{:04}_data_", i).as_bytes());
        }

        let dict = ZstdDictionary::from_content(content).unwrap();

        let input = b"xxpattern_0050_data_xxxx";
        let m = dict.find_match(input, 2);

        assert!(m.is_some(), "Should find pattern in dictionary");
        let m = m.unwrap();
        assert!(m.length >= 4, "Match should be at least 4 bytes");
    }

    // Repeated content produces many hash candidates; a match is still found.
    #[test]
    fn test_dict_multiple_match_candidates() {
        let content = b"ABCDABCDABCDABCDABCDABCDABCDABCDABCDABCD".to_vec();
        let dict = ZstdDictionary::from_content(content).unwrap();

        let input = b"ABCDEFGH";
        let m = dict.find_match(input, 0);

        assert!(m.is_some());
        let m = m.unwrap();
        assert!(m.length >= 4);
    }

    // Unrelated input yields no match.
    #[test]
    fn test_dict_no_match_found() {
        let content = b"XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX".to_vec();
        let dict = ZstdDictionary::from_content(content).unwrap();

        let input = b"ABCDEFGH";
        let m = dict.find_match(input, 0);

        assert!(m.is_none(), "Should not find match for unrelated pattern");
    }

    // All compression levels produce non-empty output through the
    // dictionary compressor.
    #[test]
    fn test_dict_compressor_with_levels() {
        let samples: Vec<&[u8]> = vec![
            b"level.test.data.one",
            b"level.test.data.two",
            b"level.test.data.three",
            b"level.test.data.four",
            b"level.test.data.five",
        ];
        let dict = ZstdDictionary::train(&samples, 2048).unwrap();

        let data = b"level.test.data with additional content to compress effectively";

        let fast =
            ZstdDictCompressor::with_level(dict.clone(), haagenti_core::CompressionLevel::Fast)
                .compress(data)
                .unwrap();

        let default =
            ZstdDictCompressor::with_level(dict.clone(), haagenti_core::CompressionLevel::Default)
                .compress(data)
                .unwrap();

        let best = ZstdDictCompressor::with_level(dict, haagenti_core::CompressionLevel::Best)
            .compress(data)
            .unwrap();

        assert!(!fast.is_empty(), "Fast compression should produce output");
        assert!(
            !default.is_empty(),
            "Default compression should produce output"
        );
        assert!(!best.is_empty(), "Best compression should produce output");
    }
}