Skip to main content

nexar/compression/
randomk.rs

1//! Random-K sampling: randomly select K% of elements.
2//!
3//! Simpler than TopK (no sort needed), and unbiased in expectation when
4//! combined with error feedback. Good for very large tensors where TopK's
5//! O(n log n) sort is too expensive.
6
7use crate::types::DataType;
8
9use super::traits::{CompressedTensor, Compressor};
10
/// Random-K gradient compressor.
///
/// Keeps a uniformly random `ratio` fraction of the elements each round.
/// Paired with error feedback (residual accumulation), the scheme is
/// unbiased in expectation.
pub struct RandomKCompressor {
    /// Fraction of elements to keep; must lie in (0.0, 1.0].
    ratio: f64,
}

impl RandomKCompressor {
    /// Create a compressor that keeps `ratio` of the elements.
    ///
    /// # Panics
    /// Panics unless `ratio` lies in (0.0, 1.0] (a NaN ratio also panics,
    /// since it fails every comparison).
    pub fn new(ratio: f64) -> Self {
        if !(ratio > 0.0 && ratio <= 1.0) {
            panic!("ratio must be in (0.0, 1.0]");
        }
        RandomKCompressor { ratio }
    }
}
26
27impl Compressor for RandomKCompressor {
28    fn compress(
29        &self,
30        input: &[u8],
31        count: usize,
32        dtype: DataType,
33        residual: &mut [u8],
34    ) -> CompressedTensor {
35        let elem_size = dtype.size_in_bytes();
36        let k = ((count as f64 * self.ratio).ceil() as usize)
37            .max(1)
38            .min(count);
39
40        // Add input to residual (error feedback).
41        add_bytes(residual, input, count, elem_size);
42
43        // Deterministic pseudo-random selection using a simple LCG seeded from residual content.
44        // This avoids pulling in a rand dependency. The seed changes each call because the
45        // residual content changes.
46        let seed = residual_hash(residual);
47        let indices = sample_indices(count, k, seed);
48
49        // Extract values from residual at selected indices.
50        let mut values = vec![0u8; k * elem_size];
51        for (i, &idx) in indices.iter().enumerate() {
52            let src_off = idx as usize * elem_size;
53            let dst_off = i * elem_size;
54            values[dst_off..dst_off + elem_size]
55                .copy_from_slice(&residual[src_off..src_off + elem_size]);
56        }
57
58        // Zero selected positions in residual.
59        for &idx in &indices {
60            let off = idx as usize * elem_size;
61            for b in &mut residual[off..off + elem_size] {
62                *b = 0;
63            }
64        }
65
66        CompressedTensor::encode(&indices, &values, count, dtype)
67    }
68
69    fn decompress(&self, compressed: &CompressedTensor, output: &mut [u8]) {
70        let k = compressed.k();
71        let elem_size = compressed.dtype.size_in_bytes();
72        let indices = compressed.decode_indices();
73        let values = compressed.values_bytes();
74
75        for (i, &idx) in indices.iter().enumerate().take(k) {
76            let src_off = i * elem_size;
77            let dst_off = idx as usize * elem_size;
78            output[dst_off..dst_off + elem_size]
79                .copy_from_slice(&values[src_off..src_off + elem_size]);
80        }
81    }
82}
83
/// Element-wise `residual[i] += input[i]` over `count` elements.
///
/// 4-byte elements are interpreted as little-endian `f32` and 8-byte
/// elements as little-endian `f64`. Any other element size falls back to
/// byte-wise *wrapping integer addition* over the overlapping prefix of the
/// two buffers — lossy for floating-point payloads, but it keeps the
/// error-feedback loop functional.
///
/// # Panics
/// The 4- and 8-byte paths panic if either buffer is shorter than
/// `count * elem_size` bytes.
fn add_bytes(residual: &mut [u8], input: &[u8], count: usize, elem_size: usize) {
    match elem_size {
        4 => {
            // Subslice once so the compiler can elide per-chunk bounds checks.
            let dst = &mut residual[..count * 4];
            let src = &input[..count * 4];
            for (r, v) in dst.chunks_exact_mut(4).zip(src.chunks_exact(4)) {
                let sum = f32::from_le_bytes([r[0], r[1], r[2], r[3]])
                    + f32::from_le_bytes([v[0], v[1], v[2], v[3]]);
                r.copy_from_slice(&sum.to_le_bytes());
            }
        }
        8 => {
            let dst = &mut residual[..count * 8];
            let src = &input[..count * 8];
            for (r, v) in dst.chunks_exact_mut(8).zip(src.chunks_exact(8)) {
                let sum = f64::from_le_bytes([r[0], r[1], r[2], r[3], r[4], r[5], r[6], r[7]])
                    + f64::from_le_bytes([v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]]);
                r.copy_from_slice(&sum.to_le_bytes());
            }
        }
        _ => {
            // Fallback: byte-wise wrapping add (the previous comment claimed
            // XOR, which was never what the code did). `zip` bounds the walk
            // to the shorter buffer, matching the original best-effort
            // semantics of ignoring `count` here.
            for (r, &v) in residual.iter_mut().zip(input.iter()) {
                *r = r.wrapping_add(v);
            }
        }
    }
}
140
/// FNV-1a-style fingerprint over a sparse sample of the residual bytes.
///
/// Touches at most 256 bytes (every 64th one), so seeding stays cheap even
/// for very large buffers while still varying with the residual's content.
fn residual_hash(residual: &[u8]) -> u64 {
    const FNV_OFFSET: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;
    residual
        .iter()
        .step_by(64)
        .take(256)
        .fold(FNV_OFFSET, |acc, &byte| {
            (acc ^ byte as u64).wrapping_mul(FNV_PRIME)
        })
}
150
/// Select `k` unique, sorted indices from `[0, n)` using a seeded LCG.
///
/// Uses a *sparse* Fisher-Yates partial shuffle: instead of materializing the
/// whole `[0, n)` pool (O(n) memory per call), only slots that have been
/// displaced are tracked in a map, bringing memory down to O(k). That
/// matters because Random-K targets very large tensors where k << n.
///
/// The draw uses the LCG's *high* bits: for a power-of-two-modulus LCG the
/// low bits have very short periods (bit 0 merely alternates), so the old
/// `state % bound` sampled poorly.
///
/// # Panics
/// Panics if `n` does not fit in `u32` (indices are encoded as `u32`;
/// previously this truncated silently) or, in debug builds, if `k > n`.
fn sample_indices(n: usize, k: usize, seed: u64) -> Vec<u32> {
    assert!(n <= u32::MAX as usize, "index space exceeds u32 range");
    debug_assert!(k <= n);

    let mut state = seed;
    let mut next_below = |bound: usize| -> usize {
        // LCG step (Knuth MMIX constants), then take the high 32 bits.
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((state >> 32) as usize) % bound
    };

    // Sparse Fisher-Yates: `displaced[p]` holds the value currently sitting
    // at slot `p` whenever it differs from the identity value `p`.
    let mut displaced: std::collections::HashMap<usize, u32> =
        std::collections::HashMap::with_capacity(2 * k);
    let mut selected = Vec::with_capacity(k);

    for i in 0..k {
        let j = i + next_below(n - i);
        let picked = displaced.get(&j).copied().unwrap_or(j as u32);
        let at_i = displaced.get(&i).copied().unwrap_or(i as u32);
        // Move whatever was at slot `i` into slot `j`; slot `i` is never
        // drawn again, so it needs no entry of its own.
        displaced.insert(j, at_i);
        selected.push(picked);
    }

    selected.sort_unstable();
    selected
}
167
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_randomk_compress_f32() {
        let compressor = RandomKCompressor::new(0.5);

        // Build the input buffer with explicit little-endian encoding. The
        // old test transmuted &[f32] to &[u8] (endianness-dependent, since
        // add_bytes decodes little-endian) and, worse, cast a Vec<u8> pointer
        // to *const f32 for decoding — a Vec<u8> allocation has no 4-byte
        // alignment guarantee, so that read was undefined behavior.
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let mut input_bytes = Vec::with_capacity(16);
        for v in &input {
            input_bytes.extend_from_slice(&v.to_le_bytes());
        }
        let mut residual = vec![0u8; 16];

        let ct = compressor.compress(&input_bytes, 4, DataType::F32, &mut residual);
        assert_eq!(ct.k(), 2);
        assert_eq!(ct.original_count, 4);

        // Decompress into a zeroed buffer (decompress only writes sampled
        // positions).
        let mut output = vec![0u8; 16];
        compressor.decompress(&ct, &mut output);

        // Decode safely, one aligned 4-byte chunk at a time.
        let out_f32: Vec<f32> = output
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect();

        // Exactly 2 of the 4 (all-nonzero) inputs should survive sampling.
        let nonzero_count = out_f32.iter().filter(|&&v| v != 0.0).count();
        assert_eq!(nonzero_count, 2);
    }

    #[test]
    fn test_sample_indices_unique() {
        let indices = sample_indices(100, 10, 42);
        assert_eq!(indices.len(), 10);
        // All unique. dedup only removes *consecutive* duplicates, so sort
        // first (the old test skipped the sort and was correct only because
        // sample_indices happens to return sorted output).
        let mut sorted = indices.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(sorted.len(), 10);
        // All in range.
        for &idx in &indices {
            assert!(idx < 100);
        }
    }
}