sklears-simd 0.1.1

High-performance SIMD acceleration primitives for the Sklears machine learning ecosystem
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
//! SIMD-optimized compression algorithms
//!
//! This module provides SIMD-accelerated implementations of common compression algorithms
//! including run-length encoding, LZ77, and dictionary-based compression.

#[cfg(feature = "no-std")]
extern crate alloc;

#[cfg(feature = "no-std")]
use alloc::{collections::BTreeMap as HashMap, vec, vec::Vec};
#[cfg(not(feature = "no-std"))]
use std::collections::HashMap;

/// Run-length encode a byte slice.
///
/// Returns `(value, count)` pairs in input order; an empty input yields an empty vector.
/// A run longer than `u32::MAX` bytes is emitted as several consecutive pairs with the
/// same value (instead of overflowing the `u32` counter), so [`run_length_decode`]
/// always reproduces the input exactly.
///
/// Runs are measured 16 bytes at a time by comparing each block against a splatted
/// copy of the current byte (fixed-size slice equality, intended to let the compiler
/// use wide compares), then finished byte by byte.
pub fn run_length_encode_simd(data: &[u8]) -> Vec<(u8, u32)> {
    let mut result = Vec::new();
    let mut rest = data;

    while let Some(&byte) = rest.first() {
        // `run >= 1` because `rest[0] == byte`.
        let run = leading_run_len(rest, byte);

        // Split over-long runs so each count fits in a u32.
        let mut remaining = run;
        while remaining > 0 {
            let piece = remaining.min(u32::MAX as usize);
            result.push((byte, piece as u32));
            remaining -= piece;
        }

        rest = &rest[run..];
    }

    result
}

/// Number of leading bytes of `bytes` that are equal to `byte`.
fn leading_run_len(bytes: &[u8], byte: u8) -> usize {
    const LANES: usize = 16; // SSE2 register width for u8
    let splat = [byte; LANES];
    let mut chunks = bytes.chunks_exact(LANES);
    let mut len = 0;

    for chunk in &mut chunks {
        if chunk != &splat[..] {
            // The first mismatch lies inside this block, so `position` always finds it.
            return len + chunk.iter().position(|&b| b != byte).unwrap_or(LANES);
        }
        len += LANES;
    }

    let tail = chunks.remainder();
    len + tail.iter().position(|&b| b != byte).unwrap_or(tail.len())
}

/// Expand `(value, count)` pairs back into the byte sequence they describe.
///
/// Pairs with a zero count contribute nothing. The output buffer is allocated once,
/// sized from the sum of all counts.
pub fn run_length_decode(encoded: &[(u8, u32)]) -> Vec<u8> {
    let decoded_len = encoded
        .iter()
        .fold(0usize, |acc, &(_, run)| acc + run as usize);
    let mut decoded = Vec::with_capacity(decoded_len);

    for &(value, run) in encoded {
        let end = decoded.len() + run as usize;
        decoded.resize(end, value);
    }

    decoded
}

/// Simple LZ77-style compressor.
///
/// The output is a byte stream of literals and 4-byte tokens:
/// - any byte other than `0xFF` is a literal;
/// - `[0xFF, d_lo, d_hi, len]` copies `len` bytes starting `d` bytes back in the
///   already-decoded output (`d` is a little-endian `u16`, always at least 1);
/// - `[0xFF, 0x00, 0x00, 0x00]` (distance 0) is an escaped literal `0xFF`.
///
/// Distances are limited to `u16::MAX` and lengths to `u8::MAX` so that they always
/// fit their fields, whatever `window_size` / `lookahead_size` are configured to.
#[derive(Debug, Clone)]
pub struct LZ77Compressor {
    window_size: usize,
    lookahead_size: usize,
}

impl LZ77Compressor {
    /// Byte that introduces a 4-byte token.
    const MARKER: u8 = 0xFF;
    /// Largest distance that fits the 2-byte distance field.
    const MAX_DISTANCE: usize = u16::MAX as usize;
    /// Largest length that fits the 1-byte length field.
    const MAX_LENGTH: usize = u8::MAX as usize;
    /// Shortest match emitted as a back-reference.
    ///
    /// NOTE(review): a token is 4 bytes, so a 3-byte match makes the output one byte
    /// longer than the literals it replaces; kept at 3 to preserve existing output.
    const MIN_MATCH: usize = 3;

    /// Create a compressor.
    ///
    /// `window_size` is how far back matches are searched (effectively capped at
    /// `u16::MAX`); `lookahead_size` is the longest match considered (effectively
    /// capped at `u8::MAX`).
    pub fn new(window_size: usize, lookahead_size: usize) -> Self {
        Self {
            window_size,
            lookahead_size,
        }
    }

    /// Find the longest earlier occurrence of the bytes starting at `pos`.
    ///
    /// Returns `(distance, length)`, or `(0, 0)` when there is no window or no
    /// lookahead. Among equally long matches the farthest one is kept. Matches never
    /// extend past `pos` (no overlapping copies).
    fn find_longest_match(&self, data: &[u8], pos: usize) -> (usize, usize) {
        let window_start = pos.saturating_sub(self.window_size.min(Self::MAX_DISTANCE));
        let lookahead_end = pos
            .saturating_add(self.lookahead_size.min(Self::MAX_LENGTH))
            .min(data.len());

        if window_start >= pos || pos >= lookahead_end {
            return (0, 0);
        }

        let lookahead = &data[pos..lookahead_end];
        let mut best_distance = 0;
        let mut best_length = 0;

        for window_pos in window_start..pos {
            let max_length = lookahead.len().min(pos - window_pos);
            let length = common_prefix_len(
                &data[window_pos..window_pos + max_length],
                &lookahead[..max_length],
            );

            if length > best_length {
                best_length = length;
                best_distance = pos - window_pos;
                // Nothing later in the window can be longer than the whole lookahead.
                if best_length == lookahead.len() {
                    break;
                }
            }
        }

        (best_distance, best_length)
    }

    /// Compress `data` into the literal/token stream described on [`LZ77Compressor`].
    pub fn compress(&self, data: &[u8]) -> Vec<u8> {
        let mut compressed = Vec::with_capacity(data.len());
        let mut pos = 0;

        while pos < data.len() {
            let (distance, length) = self.find_longest_match(data, pos);

            if length >= Self::MIN_MATCH {
                // find_longest_match clamps distance to MAX_DISTANCE and length to
                // MAX_LENGTH, so these casts cannot truncate.
                compressed.push(Self::MARKER);
                compressed.extend_from_slice(&(distance as u16).to_le_bytes());
                compressed.push(length as u8);
                pos += length;
            } else {
                let byte = data[pos];
                if byte == Self::MARKER {
                    // A bare marker would be read as the start of a token; escape it.
                    compressed.extend_from_slice(&[Self::MARKER, 0, 0, 0]);
                } else {
                    compressed.push(byte);
                }
                pos += 1;
            }
        }

        compressed
    }
}

/// Length of the common prefix of `a` and `b`.
///
/// Compares whole 16-byte blocks first (fixed-size slice equality, intended to let the
/// compiler use wide compares) and finishes byte by byte from the first differing block.
fn common_prefix_len(a: &[u8], b: &[u8]) -> usize {
    const LANES: usize = 16;
    let limit = a.len().min(b.len());
    let (a, b) = (&a[..limit], &b[..limit]);

    let equal_blocks = a
        .chunks_exact(LANES)
        .zip(b.chunks_exact(LANES))
        .take_while(|(x, y)| x == y)
        .count();
    let matched = equal_blocks * LANES;

    matched
        + a[matched..]
            .iter()
            .zip(&b[matched..])
            .take_while(|(x, y)| x == y)
            .count()
}

/// Greedy dictionary (codebook) compressor driven by substring frequency.
///
/// Codes `0..=255` always stand for the matching single byte. Codes from `256` upward
/// are assigned by [`DictionaryCompressor::build_dictionary`] to multi-byte patterns,
/// most frequent first; code `u16::MAX` is never assigned.
#[derive(Debug, Clone)]
pub struct DictionaryCompressor {
    /// Pattern -> code.
    dictionary: HashMap<Vec<u8>, u16>,
    /// Code -> pattern (inverse of `dictionary`).
    reverse_dictionary: HashMap<u16, Vec<u8>>,
    /// Next code to hand out.
    next_code: u16,
    /// Length of the longest pattern in `dictionary`; bounds the search in `compress`.
    longest_pattern: usize,
}

impl Default for DictionaryCompressor {
    fn default() -> Self {
        Self::new()
    }
}

impl DictionaryCompressor {
    /// Create a compressor whose dictionary holds only the 256 single-byte codes.
    pub fn new() -> Self {
        let mut compressor = Self {
            dictionary: HashMap::new(),
            reverse_dictionary: HashMap::new(),
            next_code: 256, // Start after single-byte codes
            longest_pattern: 1,
        };

        for byte in 0..=u8::MAX {
            let code = u16::from(byte);
            compressor.dictionary.insert(vec![byte], code);
            compressor.reverse_dictionary.insert(code, vec![byte]);
        }

        compressor
    }

    /// Add frequently repeated substrings of `data` to the dictionary.
    ///
    /// Every substring of length `2..=max_pattern_length` is counted, overlapping
    /// occurrences included. Substrings seen at least twice and not already present
    /// receive new codes in order of decreasing count, with ties broken by byte order so
    /// code assignment is deterministic. Stops once the code space is exhausted. May be
    /// called repeatedly; later calls extend the existing dictionary.
    pub fn build_dictionary(&mut self, data: &[u8], max_pattern_length: usize) {
        // Borrow substrings from `data` while counting so no window allocates.
        let mut pattern_counts: HashMap<&[u8], u32> = HashMap::new();
        for pattern_len in 2..=max_pattern_length.min(data.len()) {
            for window in data.windows(pattern_len) {
                *pattern_counts.entry(window).or_insert(0) += 1;
            }
        }

        let mut patterns: Vec<(&[u8], u32)> = pattern_counts.into_iter().collect();
        // Map iteration order is unspecified (randomised for std's HashMap), so an
        // explicit tie-break is needed for reproducible codes.
        patterns.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));

        for (pattern, count) in patterns {
            // Sorted by descending count: once one pattern is rare, all later ones are.
            if count < 2 || self.next_code == u16::MAX {
                break;
            }
            if self.dictionary.contains_key(pattern) {
                continue;
            }
            let code = self.next_code;
            self.dictionary.insert(pattern.to_vec(), code);
            self.reverse_dictionary.insert(code, pattern.to_vec());
            self.next_code += 1;
            self.longest_pattern = self.longest_pattern.max(pattern.len());
        }
    }

    /// Compress `data` into a sequence of dictionary codes.
    ///
    /// Greedy: at each position the longest dictionary pattern that matches is taken
    /// (searching up to the longest pattern built so far); otherwise the single-byte
    /// code is emitted.
    pub fn compress(&self, data: &[u8]) -> Vec<u16> {
        let mut compressed = Vec::new();
        let mut rest = data;

        while let Some(&first) = rest.first() {
            let longest = self.longest_pattern.min(rest.len());
            let (code, consumed) = (2..=longest)
                .rev()
                .find_map(|len| self.dictionary.get(&rest[..len]).map(|&code| (code, len)))
                .unwrap_or((u16::from(first), 1));
            compressed.push(code);
            rest = &rest[consumed..];
        }

        compressed
    }

    /// Expand dictionary codes back into bytes.
    ///
    /// # Errors
    ///
    /// Returns `Err("Invalid code in compressed data")` if a code has no dictionary entry.
    pub fn decompress(&self, compressed: &[u16]) -> Result<Vec<u8>, &'static str> {
        let mut decompressed = Vec::new();

        for code in compressed {
            let pattern = self
                .reverse_dictionary
                .get(code)
                .ok_or("Invalid code in compressed data")?;
            decompressed.extend_from_slice(pattern);
        }

        Ok(decompressed)
    }
}

/// Count how many times each byte value occurs in `data`.
///
/// Entry `b` of the returned table is the number of occurrences of byte `b`. Counts are
/// `u32`, so more than `u32::MAX` copies of a single byte overflow (a panic in debug
/// builds).
pub fn count_byte_frequencies_simd(data: &[u8]) -> [u32; 256] {
    let mut histogram = [0u32; 256];
    for &byte in data {
        histogram[usize::from(byte)] += 1;
    }
    histogram
}

/// Ratio of compressed size to original size (smaller means better compression).
///
/// Returns `0.0` when `original_size` is zero instead of dividing by zero.
pub fn compression_ratio(original_size: usize, compressed_size: usize) -> f64 {
    match original_size {
        0 => 0.0,
        original => compressed_size as f64 / original as f64,
    }
}

#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    #[test]
    fn test_run_length_encode() {
        let data = b"aaabbbccccdddd";
        let encoded = run_length_encode_simd(data);
        let expected = vec![(b'a', 3), (b'b', 3), (b'c', 4), (b'd', 4)];
        assert_eq!(encoded, expected);
    }

    #[test]
    fn test_run_length_decode() {
        let encoded = vec![(b'a', 3), (b'b', 3), (b'c', 4), (b'd', 4)];
        let decoded = run_length_decode(&encoded);
        assert_eq!(decoded, b"aaabbbccccdddd");
    }

    #[test]
    fn test_run_length_roundtrip() {
        let original = b"aaaaabbbbcccccdddddeeeeee";
        let encoded = run_length_encode_simd(original);
        let decoded = run_length_decode(&encoded);
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_run_length_chunk_boundaries() {
        // Runs that start, end, and straddle 16-byte block boundaries.
        let mut data = vec![b'x'; 17];
        data.push(b'y');
        data.extend_from_slice(&[b'z'; 33]);
        let encoded = run_length_encode_simd(&data);
        assert_eq!(encoded, vec![(b'x', 17), (b'y', 1), (b'z', 33)]);
        assert_eq!(run_length_decode(&encoded), data);
    }

    #[test]
    fn test_run_length_roundtrip_pseudo_random() {
        // Small alphabet so runs of varying length occur; an LCG keeps it deterministic.
        let mut state = 0x1234_5678u32;
        let data: Vec<u8> = (0..500)
            .map(|_| {
                state = state.wrapping_mul(1_103_515_245).wrapping_add(12_345);
                ((state >> 16) % 3) as u8
            })
            .collect();
        let encoded = run_length_encode_simd(&data);
        assert!(encoded.windows(2).all(|w| w[0].0 != w[1].0));
        assert_eq!(run_length_decode(&encoded), data);
    }

    #[test]
    fn test_lz77_compression() {
        let compressor = LZ77Compressor::new(1024, 32);
        let data = b"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz";
        let compressed = compressor.compress(data);

        // Should be able to compress repeated patterns
        assert!(compressed.len() < data.len());
    }

    #[test]
    fn test_lz77_empty_input() {
        let compressor = LZ77Compressor::new(1024, 32);
        assert!(compressor.compress(&[]).is_empty());
    }

    #[test]
    fn test_dictionary_compression() {
        let mut compressor = DictionaryCompressor::new();
        let data = b"hello world hello world hello world";

        compressor.build_dictionary(data, 8);
        let compressed = compressor.compress(data);
        let decompressed = compressor
            .decompress(&compressed)
            .expect("operation should succeed");

        assert_eq!(decompressed, data);

        // Calculate compression efficiency
        let original_bits = data.len() * 8;
        let compressed_bits = compressed.len() * 16; // 16 bits per code
        assert!(compressed_bits < original_bits);
    }

    #[test]
    fn test_dictionary_invalid_code() {
        let compressor = DictionaryCompressor::default();
        assert!(compressor.decompress(&[300]).is_err());
    }

    #[test]
    fn test_byte_frequency_counter() {
        let data = b"hello world";
        let frequencies = count_byte_frequencies_simd(data);

        assert_eq!(frequencies[b'h' as usize], 1);
        assert_eq!(frequencies[b'e' as usize], 1);
        assert_eq!(frequencies[b'l' as usize], 3);
        assert_eq!(frequencies[b'o' as usize], 2);
        assert_eq!(frequencies[b' ' as usize], 1);
        assert_eq!(frequencies[b'w' as usize], 1);
        assert_eq!(frequencies[b'r' as usize], 1);
        assert_eq!(frequencies[b'd' as usize], 1);
    }

    #[test]
    fn test_compression_ratio() {
        let ratio = compression_ratio(1000, 750);
        assert!((ratio - 0.75).abs() < f64::EPSILON);

        let ratio_zero = compression_ratio(0, 100);
        assert_eq!(ratio_zero, 0.0);
    }

    #[test]
    fn test_empty_data() {
        let empty_data = b"";
        let encoded = run_length_encode_simd(empty_data);
        assert!(encoded.is_empty());

        let decoded = run_length_decode(&[]);
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_single_byte() {
        let data = b"a";
        let encoded = run_length_encode_simd(data);
        assert_eq!(encoded, vec![(b'a', 1)]);

        let decoded = run_length_decode(&encoded);
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_long_runs() {
        let data = vec![b'x'; 1000];
        let encoded = run_length_encode_simd(&data);
        assert_eq!(encoded, vec![(b'x', 1000)]);

        let decoded = run_length_decode(&encoded);
        assert_eq!(decoded, data);
    }
}