nexar/compression/
randomk.rs1use crate::types::DataType;
8
9use super::traits::{CompressedTensor, Compressor};
10
/// Random-k sparsifying compressor.
///
/// Each `compress` call keeps `ceil(count * ratio)` elements (at least one
/// when `count > 0`, never more than `count`), chosen pseudo-randomly, and
/// leaves the remaining elements in the caller-supplied residual buffer.
#[derive(Debug, Clone, PartialEq)]
pub struct RandomKCompressor {
    /// Fraction of elements kept per call; `new` enforces `0.0 < ratio <= 1.0`.
    ratio: f64,
}
19
20impl RandomKCompressor {
21 pub fn new(ratio: f64) -> Self {
22 assert!(ratio > 0.0 && ratio <= 1.0, "ratio must be in (0.0, 1.0]");
23 Self { ratio }
24 }
25}
26
27impl Compressor for RandomKCompressor {
28 fn compress(
29 &self,
30 input: &[u8],
31 count: usize,
32 dtype: DataType,
33 residual: &mut [u8],
34 ) -> CompressedTensor {
35 let elem_size = dtype.size_in_bytes();
36 let k = ((count as f64 * self.ratio).ceil() as usize)
37 .max(1)
38 .min(count);
39
40 add_bytes(residual, input, count, elem_size);
42
43 let seed = residual_hash(residual);
47 let indices = sample_indices(count, k, seed);
48
49 let mut values = vec![0u8; k * elem_size];
51 for (i, &idx) in indices.iter().enumerate() {
52 let src_off = idx as usize * elem_size;
53 let dst_off = i * elem_size;
54 values[dst_off..dst_off + elem_size]
55 .copy_from_slice(&residual[src_off..src_off + elem_size]);
56 }
57
58 for &idx in &indices {
60 let off = idx as usize * elem_size;
61 for b in &mut residual[off..off + elem_size] {
62 *b = 0;
63 }
64 }
65
66 CompressedTensor::encode(&indices, &values, count, dtype)
67 }
68
69 fn decompress(&self, compressed: &CompressedTensor, output: &mut [u8]) {
70 let k = compressed.k();
71 let elem_size = compressed.dtype.size_in_bytes();
72 let indices = compressed.decode_indices();
73 let values = compressed.values_bytes();
74
75 for (i, &idx) in indices.iter().enumerate().take(k) {
76 let src_off = i * elem_size;
77 let dst_off = idx as usize * elem_size;
78 output[dst_off..dst_off + elem_size]
79 .copy_from_slice(&values[src_off..src_off + elem_size]);
80 }
81 }
82}
83
/// Accumulates `input` into `residual` in place.
///
/// The element width selects the arithmetic:
/// * `4` — the first `count` elements of both buffers are read as
///   little-endian `f32`, summed, and written back to `residual`;
/// * `8` — the same, with little-endian `f64`;
/// * anything else — a byte-wise wrapping add over the overlapping prefix of
///   the two buffers (`count` is not consulted in this branch).
///
/// NOTE(review): only the width is available here, so 4- and 8-byte integer
/// dtypes would also be summed as floats, and 2-byte floats get a byte-wise
/// add. Presumably callers only pass `f32`/`f64` tensors — confirm.
///
/// # Panics
///
/// For widths 4 and 8, panics if either buffer is shorter than
/// `count * elem_size` bytes.
fn add_bytes(residual: &mut [u8], input: &[u8], count: usize, elem_size: usize) {
    /// Copies the `N` bytes starting at `offset` into a fixed-size array.
    fn read_array<const N: usize>(bytes: &[u8], offset: usize) -> [u8; N] {
        let mut out = [0u8; N];
        out.copy_from_slice(&bytes[offset..offset + N]);
        out
    }

    match elem_size {
        4 => {
            for offset in (0..count).map(|i| i * 4) {
                let acc = f32::from_le_bytes(read_array::<4>(residual, offset));
                let add = f32::from_le_bytes(read_array::<4>(input, offset));
                residual[offset..offset + 4].copy_from_slice(&(acc + add).to_le_bytes());
            }
        }
        8 => {
            for offset in (0..count).map(|i| i * 8) {
                let acc = f64::from_le_bytes(read_array::<8>(residual, offset));
                let add = f64::from_le_bytes(read_array::<8>(input, offset));
                residual[offset..offset + 8].copy_from_slice(&(acc + add).to_le_bytes());
            }
        }
        _ => {
            // Fallback for other widths: plain modular byte addition.
            for (acc, &add) in residual.iter_mut().zip(input) {
                *acc = acc.wrapping_add(add);
            }
        }
    }
}
140
/// Derives a 64-bit seed from a sparse sample of `residual`.
///
/// Runs FNV-1a (xor the byte, then multiply by the prime) over every 64th
/// byte starting at offset 0, stopping after at most 256 samples. Bytes at
/// other offsets do not influence the result; an empty slice yields the FNV
/// offset basis.
fn residual_hash(residual: &[u8]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;
    const STRIDE: usize = 64;
    const MAX_SAMPLES: usize = 256;

    residual
        .iter()
        .step_by(STRIDE)
        .take(MAX_SAMPLES)
        .fold(FNV_OFFSET_BASIS, |hash, &byte| {
            (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        })
}
150
/// Picks `k` distinct indices from `0..n` and returns them in ascending order.
///
/// Runs the first `k` steps of a Fisher–Yates shuffle over `0..n`, driven by
/// a 64-bit linear congruential generator seeded with `seed`, so the same
/// `(n, k, seed)` always yields the same indices.
///
/// `k` is clamped to `n`: asking for more indices than exist returns all of
/// `0..n` instead of hitting a remainder-by-zero once the pool is exhausted.
///
/// # Panics
///
/// Panics if `n` exceeds `u32::MAX`, because indices are stored as `u32` and
/// a larger `n` would otherwise be truncated silently.
fn sample_indices(n: usize, k: usize, seed: u64) -> Vec<u32> {
    assert!(
        n <= u32::MAX as usize,
        "sample_indices: n = {} does not fit in u32 indices",
        n
    );
    let k = k.min(n);

    let mut state = seed;
    let mut pool: Vec<u32> = (0..n as u32).collect();
    for i in 0..k {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // NOTE(review): the low bits of this generator are weak (bit 0 simply
        // alternates, since multiplier and increment are both odd). The draw
        // is kept as-is so results for a given seed do not change.
        let j = i + (state as usize % (n - i));
        pool.swap(i, j);
    }

    // The first `k` slots now hold the selection; sort it in place.
    pool.truncate(k);
    pool.sort_unstable();
    pool
}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `f32` values to the little-endian byte layout that
    /// `add_bytes` reads.
    fn f32s_to_le_bytes(values: &[f32]) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(values.len() * 4);
        for v in values {
            bytes.extend_from_slice(&v.to_le_bytes());
        }
        bytes
    }

    /// Decodes a little-endian byte buffer into `f32` values without any
    /// pointer cast, so the buffer's alignment does not matter.
    fn le_bytes_to_f32s(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect()
    }

    #[test]
    fn test_randomk_compress_f32() {
        let compressor = RandomKCompressor::new(0.5);
        let input_bytes = f32s_to_le_bytes(&[1.0, 2.0, 3.0, 4.0]);
        let mut residual = vec![0u8; 16];

        let ct = compressor.compress(&input_bytes, 4, DataType::F32, &mut residual);
        assert_eq!(ct.k(), 2);
        assert_eq!(ct.original_count, 4);

        // Start from zeros so that only the scattered elements are non-zero.
        let mut output = vec![0u8; 16];
        compressor.decompress(&ct, &mut output);

        // Every input value is non-zero, so exactly `k` outputs must be too.
        let out_f32 = le_bytes_to_f32s(&output);
        let nonzero_count = out_f32.iter().filter(|&&v| v != 0.0).count();
        assert_eq!(nonzero_count, 2);
    }

    #[test]
    fn test_sample_indices_unique() {
        let indices = sample_indices(100, 10, 42);
        assert_eq!(indices.len(), 10);
        let mut sorted = indices.clone();
        // `sample_indices` returns its result sorted, so `dedup` alone is
        // enough to drop any duplicates.
        sorted.dedup();
        assert_eq!(sorted.len(), 10);
        for &idx in &indices {
            assert!(idx < 100);
        }
    }
}