1#[cfg(feature = "no-std")]
7extern crate alloc;
8
9#[cfg(feature = "no-std")]
10use alloc::collections::BTreeMap as HashMap;
11#[cfg(not(feature = "no-std"))]
12use std::{collections::HashMap, vec};
13
/// Error budget attached to an approximate computation.
///
/// Callers pass one in to describe the accuracy they will accept; the kernels in
/// `approximate_ops` hand back an adjusted copy next to their result.
#[derive(Debug, Clone, Copy)]
pub struct ErrorBound {
    /// Relative component of the bound.
    pub relative_error: f64,
    /// Absolute component of the bound (in the units of the result).
    pub absolute_error: f64,
    /// Probability attached to the bound; presumably in [0, 1] (not enforced).
    pub probability: f64,
}

impl ErrorBound {
    /// Strictest preset: 1% relative, 1e-6 absolute, probability 0.99.
    pub const TIGHT: Self = Self {
        relative_error: 0.01,
        absolute_error: 1.0e-6,
        probability: 0.99,
    };

    /// Middle preset: 5% relative, 1e-4 absolute, probability 0.95.
    pub const MODERATE: Self = Self {
        relative_error: 0.05,
        absolute_error: 1.0e-4,
        probability: 0.95,
    };

    /// Loosest preset: 10% relative, 1e-3 absolute, probability 0.9.
    pub const RELAXED: Self = Self {
        relative_error: 0.1,
        absolute_error: 1.0e-3,
        probability: 0.9,
    };
}
41
42pub mod approximate_ops {
44 use super::*;
45
46 pub fn approximate_dot_product_f32(
48 a: &[f32],
49 b: &[f32],
50 error_bound: ErrorBound,
51 ) -> (f32, ErrorBound) {
52 assert_eq!(a.len(), b.len());
53
54 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
55 {
56 if crate::simd_feature_detected!("avx2") {
57 return unsafe { approximate_dot_product_f32_avx2(a, b, error_bound) };
58 }
59 }
60
61 approximate_dot_product_f32_scalar(a, b, error_bound)
62 }
63
64 pub fn approximate_sum_f32(data: &[f32], error_bound: ErrorBound) -> (f32, ErrorBound) {
66 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
67 {
68 if crate::simd_feature_detected!("avx2") {
69 return unsafe { approximate_sum_f32_avx2(data, error_bound) };
70 }
71 }
72
73 approximate_sum_f32_scalar(data, error_bound)
74 }
75
76 pub fn approximate_l2_norm_f32(data: &[f32], error_bound: ErrorBound) -> (f32, ErrorBound) {
78 let (sum_squares, error) = approximate_sum_of_squares_f32(data, error_bound);
79 let norm = sum_squares.sqrt();
80
81 let propagated_error = ErrorBound {
83 relative_error: error.relative_error * 0.5, absolute_error: error.absolute_error * 0.5,
85 probability: error.probability,
86 };
87
88 (norm, propagated_error)
89 }
90
91 pub fn approximate_sum_of_squares_f32(
93 data: &[f32],
94 error_bound: ErrorBound,
95 ) -> (f32, ErrorBound) {
96 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
97 {
98 if crate::simd_feature_detected!("avx2") {
99 return unsafe { approximate_sum_of_squares_f32_avx2(data, error_bound) };
100 }
101 }
102
103 approximate_sum_of_squares_f32_scalar(data, error_bound)
104 }
105
106 fn approximate_dot_product_f32_scalar(
108 a: &[f32],
109 b: &[f32],
110 error_bound: ErrorBound,
111 ) -> (f32, ErrorBound) {
112 let mut sum = 0.0f32;
114
115 for (&x, &y) in a.iter().zip(b.iter()) {
116 let x_approx = quantize_f32(x, 16); let y_approx = quantize_f32(y, 16);
119 sum += x_approx * y_approx;
120 }
121
122 let estimated_error = ErrorBound {
124 relative_error: (error_bound.relative_error + 0.001).min(0.1),
125 absolute_error: error_bound.absolute_error + 1e-5,
126 probability: error_bound.probability * 0.95,
127 };
128
129 (sum, estimated_error)
130 }
131
132 fn approximate_sum_f32_scalar(data: &[f32], error_bound: ErrorBound) -> (f32, ErrorBound) {
133 let mut sum = 0.0f32;
135 let mut c = 0.0f32; for &x in data {
138 let x_approx = quantize_f32(x, 16);
139 let y = x_approx - c;
140 let t = sum + y;
141 c = (t - sum) - y;
142 sum = t;
143 }
144
145 let estimated_error = ErrorBound {
146 relative_error: (error_bound.relative_error + 0.0005).min(0.05),
147 absolute_error: error_bound.absolute_error + 1e-6,
148 probability: error_bound.probability * 0.98,
149 };
150
151 (sum, estimated_error)
152 }
153
154 fn approximate_sum_of_squares_f32_scalar(
155 data: &[f32],
156 error_bound: ErrorBound,
157 ) -> (f32, ErrorBound) {
158 let mut sum = 0.0f32;
159
160 for &x in data {
161 let x_approx = quantize_f32(x, 16);
162 sum += x_approx * x_approx;
163 }
164
165 let estimated_error = ErrorBound {
166 relative_error: (error_bound.relative_error + 0.002).min(0.1),
167 absolute_error: error_bound.absolute_error + 1e-5,
168 probability: error_bound.probability * 0.95,
169 };
170
171 (sum, estimated_error)
172 }
173
174 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
176 #[target_feature(enable = "avx2")]
177 unsafe fn approximate_dot_product_f32_avx2(
178 a: &[f32],
179 b: &[f32],
180 error_bound: ErrorBound,
181 ) -> (f32, ErrorBound) {
182 use core::arch::x86_64::*;
183
184 let mut sum_vec = _mm256_setzero_ps();
185 let chunks_a = a.chunks_exact(8);
186 let chunks_b = b.chunks_exact(8);
187 let remainder_a = chunks_a.remainder();
188 let remainder_b = chunks_b.remainder();
189
190 for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
191 let vec_a = _mm256_loadu_ps(chunk_a.as_ptr());
192 let vec_b = _mm256_loadu_ps(chunk_b.as_ptr());
193
194 sum_vec = _mm256_fmadd_ps(vec_a, vec_b, sum_vec);
196 }
197
198 let sum_high = _mm256_extractf128_ps(sum_vec, 1);
200 let sum_low = _mm256_castps256_ps128(sum_vec);
201 let sum128 = _mm_add_ps(sum_high, sum_low);
202 let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
203 let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
204 let mut result = _mm_cvtss_f32(sum32);
205
206 for (&x, &y) in remainder_a.iter().zip(remainder_b.iter()) {
208 result += x * y;
209 }
210
211 let estimated_error = ErrorBound {
212 relative_error: error_bound.relative_error * 0.8, absolute_error: error_bound.absolute_error,
214 probability: error_bound.probability * 0.99,
215 };
216
217 (result, estimated_error)
218 }
219
220 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
221 #[target_feature(enable = "avx2")]
222 unsafe fn approximate_sum_f32_avx2(data: &[f32], error_bound: ErrorBound) -> (f32, ErrorBound) {
223 use core::arch::x86_64::*;
224
225 let mut sum_vec = _mm256_setzero_ps();
226 let chunks = data.chunks_exact(8);
227 let remainder = chunks.remainder();
228
229 for chunk in chunks {
230 let vec = _mm256_loadu_ps(chunk.as_ptr());
231 sum_vec = _mm256_add_ps(sum_vec, vec);
232 }
233
234 let sum_high = _mm256_extractf128_ps(sum_vec, 1);
236 let sum_low = _mm256_castps256_ps128(sum_vec);
237 let sum128 = _mm_add_ps(sum_high, sum_low);
238 let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
239 let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
240 let mut result = _mm_cvtss_f32(sum32);
241
242 for &x in remainder {
244 result += x;
245 }
246
247 let estimated_error = ErrorBound {
248 relative_error: error_bound.relative_error * 0.9,
249 absolute_error: error_bound.absolute_error,
250 probability: error_bound.probability * 0.99,
251 };
252
253 (result, estimated_error)
254 }
255
256 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
257 #[target_feature(enable = "avx2")]
258 unsafe fn approximate_sum_of_squares_f32_avx2(
259 data: &[f32],
260 error_bound: ErrorBound,
261 ) -> (f32, ErrorBound) {
262 use core::arch::x86_64::*;
263
264 let mut sum_vec = _mm256_setzero_ps();
265 let chunks = data.chunks_exact(8);
266 let remainder = chunks.remainder();
267
268 for chunk in chunks {
269 let vec = _mm256_loadu_ps(chunk.as_ptr());
270 sum_vec = _mm256_fmadd_ps(vec, vec, sum_vec);
271 }
272
273 let sum_high = _mm256_extractf128_ps(sum_vec, 1);
275 let sum_low = _mm256_castps256_ps128(sum_vec);
276 let sum128 = _mm_add_ps(sum_high, sum_low);
277 let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
278 let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
279 let mut result = _mm_cvtss_f32(sum32);
280
281 for &x in remainder {
283 result += x * x;
284 }
285
286 let estimated_error = ErrorBound {
287 relative_error: error_bound.relative_error * 0.85,
288 absolute_error: error_bound.absolute_error,
289 probability: error_bound.probability * 0.99,
290 };
291
292 (result, estimated_error)
293 }
294
/// Rounds `value` to the nearest multiple of `2^-bits` (keeps `bits` fractional bits).
///
/// `bits >= 32` disables quantization. NaN and infinities pass through unchanged.
///
/// Values whose scaled form would overflow `f32` (magnitude above `f32::MAX / 2^bits`) are
/// also returned unchanged. Such magnitudes are integers far coarser than the target grid,
/// so they are already exactly representable on it. Scaling them would produce ±infinity
/// and turn a finite input into an infinite output.
fn quantize_f32(value: f32, bits: u8) -> f32 {
    if bits >= 32 {
        return value;
    }

    let scale = (1u32 << bits) as f32;
    let scaled = value * scale;
    if !scaled.is_finite() {
        return value;
    }

    scaled.round() / scale
}
305}
306
307pub mod reduced_precision {
309 #[cfg(feature = "no-std")]
310 use alloc::{vec, vec::Vec};
311
/// IEEE 754 binary16 ("half precision") value, stored as its raw bit pattern:
/// 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct F16 {
    bits: u16,
}

impl F16 {
    /// Converts an `f32` to the nearest binary16 value (round to nearest, ties to even).
    ///
    /// - NaN maps to the canonical quiet NaN `0x7e00`; the sign and payload are dropped.
    /// - Magnitudes that round above the largest finite half (65504) become ±infinity.
    /// - Magnitudes below the smallest normal half (2^-14) are encoded as subnormals, and
    ///   those below half the smallest subnormal (2^-25) round to ±0.
    /// - The sign of zero is preserved.
    pub fn from_f32(value: f32) -> Self {
        const HALF_INFINITY: u16 = 0x7c00;

        if value.is_nan() {
            return Self { bits: 0x7e00 };
        }

        let x = value.to_bits();
        let sign = ((x >> 16) & 0x8000) as u16;
        let exp_field = ((x >> 23) & 0xff) as i32;
        let mantissa = x & 0x007f_ffff;

        // All-ones exponent with NaN already excluded: infinity.
        if exp_field == 0xff {
            return Self { bits: sign | HALF_INFINITY };
        }

        // Re-bias the exponent from f32 (127) to binary16 (15).
        let half_exp = exp_field - 127 + 15;

        if half_exp >= 0x1f {
            return Self { bits: sign | HALF_INFINITY };
        }

        if half_exp <= 0 {
            // Subnormal half (or zero). The value in units of 2^-24 is the full 24-bit
            // significand shifted right by `14 - half_exp`.
            let shift = (14 - half_exp) as u32;
            if shift > 24 {
                // Below 2^-25 (includes f32 zeros and subnormals): rounds to signed zero.
                return Self { bits: sign };
            }
            let significand = mantissa | 0x0080_0000;
            let round_bit = 1u32 << (shift - 1);
            // Round up when the round bit is set and either a lower (sticky) bit or the
            // result's least significant bit is set: round half to even.
            let round_up =
                (significand & round_bit) != 0 && (significand & (3 * round_bit - 1)) != 0;
            // A carry out of the 10 mantissa bits correctly produces the smallest normal.
            let half_man = (significand >> shift) + u32::from(round_up);
            return Self { bits: sign | (half_man as u16) };
        }

        // Normal half: keep the top 10 of 23 mantissa bits, rounding half to even on bit 12.
        let round_bit = 0x0000_1000u32;
        let round_up = (mantissa & round_bit) != 0 && (mantissa & (3 * round_bit - 1)) != 0;
        // A carry out of the mantissa bumps the exponent; from 0x7bff it yields infinity.
        let magnitude = (((half_exp as u32) << 10) | (mantissa >> 13)) + u32::from(round_up);
        Self { bits: sign | (magnitude as u16) }
    }

    /// Converts back to `f32`; exact, since every binary16 value is representable in `f32`.
    ///
    /// Any NaN pattern decodes to `f32::NAN`.
    pub fn to_f32(self) -> f32 {
        let negative = (self.bits & 0x8000) != 0;
        let exp = (self.bits >> 10) & 0x1f;
        let mantissa = self.bits & 0x3ff;

        let magnitude = match (exp, mantissa) {
            (0, 0) => 0.0,
            // Subnormal: no implicit leading one, fixed exponent -14.
            (0, m) => (m as f32) / 1024.0 * 2.0_f32.powi(-14),
            (0x1f, 0) => f32::INFINITY,
            (0x1f, _) => return f32::NAN,
            (e, m) => (1.0 + (m as f32) / 1024.0) * 2.0_f32.powi(e as i32 - 15),
        };

        if negative {
            -magnitude
        } else {
            magnitude
        }
    }
}
397
/// Affine quantizer mapping `f32` values to `u8` codes via `code = value / scale + zero_point`.
pub struct U8Quantized {
    scale: f32,
    zero_point: u8,
}

impl U8Quantized {
    /// Creates a quantizer spreading `[min_val, max_val]` over the 256 `u8` codes.
    ///
    /// If that range gives a scale that is zero, subnormal, or non-finite (for example
    /// `min_val == max_val`), a scale of `1.0` is used instead. Without the fallback,
    /// `new` and `quantize` would divide by zero and every value would dequantize to 0.
    pub fn new(min_val: f32, max_val: f32) -> Self {
        let raw_scale = (max_val - min_val) / 255.0;
        let scale = if raw_scale.is_normal() { raw_scale } else { 1.0 };
        let zero_point = (-min_val / scale).round().clamp(0.0, 255.0) as u8;

        Self { scale, zero_point }
    }

    /// Maps `value` to the nearest code, saturating at 0 and 255.
    pub fn quantize(&self, value: f32) -> u8 {
        ((value / self.scale) + self.zero_point as f32)
            .round()
            .clamp(0.0, 255.0) as u8
    }

    /// Maps a code back to its representative `f32` value.
    pub fn dequantize(&self, quantized: u8) -> f32 {
        (quantized as f32 - self.zero_point as f32) * self.scale
    }

    /// Dot product of two code vectors, returned in dequantized units.
    ///
    /// Elements past the end of the shorter slice are ignored (`zip` semantics). The integer
    /// accumulator is `i64`: each term can reach 255^2, so an `i32` overflows after roughly
    /// 33k elements.
    pub fn quantized_dot_product(&self, a: &[u8], b: &[u8]) -> f32 {
        let zero_point = i64::from(self.zero_point);
        let sum: i64 = a
            .iter()
            .zip(b.iter())
            .map(|(&x, &y)| (i64::from(x) - zero_point) * (i64::from(y) - zero_point))
            .sum();

        sum as f32 * self.scale * self.scale
    }
}
436
437 pub fn mixed_precision_matrix_multiply(
439 a: &[f32],
440 b: &[f32],
441 rows_a: usize,
442 cols_a: usize,
443 cols_b: usize,
444 ) -> Vec<f32> {
445 assert_eq!(a.len(), rows_a * cols_a);
446 assert_eq!(b.len(), cols_a * cols_b);
447
448 let mut result = vec![0.0f32; rows_a * cols_b];
449
450 let a_f16: Vec<F16> = a.iter().map(|&x| F16::from_f32(x)).collect();
452 let b_f16: Vec<F16> = b.iter().map(|&x| F16::from_f32(x)).collect();
453
454 for i in 0..rows_a {
455 for j in 0..cols_b {
456 let mut sum = 0.0f32;
457 for k in 0..cols_a {
458 let a_val = a_f16[i * cols_a + k].to_f32();
459 let b_val = b_f16[k * cols_b + j].to_f32();
460 sum += a_val * b_val;
461 }
462 result[i * cols_b + j] = sum;
463 }
464 }
465
466 result
467 }
468}
469
470pub mod probabilistic {
472 use super::*;
473 #[cfg(feature = "no-std")]
474 use alloc::{vec, vec::Vec};
475
476 pub struct CountMinSketch {
478 table: Vec<Vec<u32>>,
479 hash_functions: Vec<u64>,
480 width: usize,
481 #[allow(dead_code)] depth: usize,
483 }
484
485 impl CountMinSketch {
486 pub fn new(width: usize, depth: usize) -> Self {
487 use scirs2_core::random::thread_rng;
488 let mut rng = thread_rng();
489 let hash_functions: Vec<u64> = (0..depth).map(|_| rng.random::<u64>()).collect();
490
491 Self {
492 table: vec![vec![0; width]; depth],
493 hash_functions,
494 width,
495 depth,
496 }
497 }
498
499 pub fn update(&mut self, item: u64, count: u32) {
500 for (i, &hash_seed) in self.hash_functions.iter().enumerate() {
501 let hash = self.hash_item(item, hash_seed);
502 let index = (hash as usize) % self.width;
503 self.table[i][index] = self.table[i][index].saturating_add(count);
504 }
505 }
506
507 pub fn estimate(&self, item: u64) -> u32 {
508 self.hash_functions
509 .iter()
510 .enumerate()
511 .map(|(i, &hash_seed)| {
512 let hash = self.hash_item(item, hash_seed);
513 let index = (hash as usize) % self.width;
514 self.table[i][index]
515 })
516 .min()
517 .unwrap_or(0)
518 }
519
520 fn hash_item(&self, item: u64, seed: u64) -> u64 {
521 let mut hash = seed.wrapping_mul(14695981039346656037u64);
523 let bytes = item.to_le_bytes();
524 for byte in bytes {
525 hash ^= byte as u64;
526 hash = hash.wrapping_mul(1099511628211);
527 }
528 hash
529 }
530 }
531
/// HyperLogLog distinct-count estimator with `2^precision` one-byte registers.
pub struct HyperLogLog {
    buckets: Vec<u8>,
    bucket_count: usize,
    alpha: f64,
}

impl HyperLogLog {
    /// Creates an empty estimator with `2^precision` registers.
    ///
    /// `alpha` is the bias-correction constant for the register count: tabulated for 16, 32
    /// and 64 registers, and `0.7213 / (1 + 1.079 / m)` otherwise.
    pub fn new(precision: u8) -> Self {
        let bucket_count = 1 << precision;
        let alpha = match bucket_count {
            16 => 0.673,
            32 => 0.697,
            64 => 0.709,
            _ => 0.7213 / (1.0 + 1.079 / bucket_count as f64),
        };

        Self {
            buckets: vec![0; bucket_count],
            bucket_count,
            alpha,
        }
    }

    /// Records one observation of `item`.
    pub fn add(&mut self, item: u64) {
        let hash = self.hash_item(item);
        let precision = self.bucket_count.trailing_zeros();
        // The low `precision` bits select the register.
        let bucket = (hash & ((self.bucket_count - 1) as u64)) as usize;
        // The other 64 - precision bits determine the rank: the position of their first
        // set bit. After `>> precision` the top `precision` bits are always zero. They are
        // not part of the rank and are subtracted out. Counting them would inflate every
        // register by `precision` and the raw estimate by about 2^precision.
        let remaining_hash = hash >> precision;
        let rank = (remaining_hash.leading_zeros() - precision + 1) as u8;

        self.buckets[bucket] = self.buckets[bucket].max(rank);
    }

    /// Estimated number of distinct items added so far.
    ///
    /// Uses the raw HyperLogLog estimator, switching to linear counting
    /// (`m * ln(m / zero_registers)`) when the raw value is at most `2.5 * m` and some
    /// register is still zero.
    pub fn estimate(&self) -> f64 {
        let m = self.bucket_count as f64;
        let inverse_power_sum = self
            .buckets
            .iter()
            .map(|&b| 2.0_f64.powi(-(b as i32)))
            .sum::<f64>();
        let raw_estimate = self.alpha * m.powi(2) / inverse_power_sum;

        if raw_estimate <= 2.5 * m {
            let zeros = self.buckets.iter().filter(|&&b| b == 0).count();
            if zeros != 0 {
                return m * (m / zeros as f64).ln();
            }
        }

        raw_estimate
    }

    /// FNV-1a over the item's little-endian bytes, followed by a SplitMix64 finalizer.
    fn hash_item(&self, item: u64) -> u64 {
        let mut hash = 14695981039346656037u64;
        for byte in item.to_le_bytes() {
            hash ^= byte as u64;
            hash = hash.wrapping_mul(1099511628211);
        }
        // FNV-1a's low bits (which pick the register) depend only on the low bits of each
        // byte, so mix the whole word before using it.
        hash = (hash ^ (hash >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
        hash = (hash ^ (hash >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
        hash ^ (hash >> 31)
    }
}
597
598 pub struct BloomFilter {
600 bit_array: Vec<bool>,
601 hash_functions: Vec<u64>,
602 size: usize,
603 #[allow(dead_code)] hash_count: usize,
605 }
606
607 impl BloomFilter {
608 pub fn new(expected_elements: usize, false_positive_rate: f64) -> Self {
609 let size = (-(expected_elements as f64 * false_positive_rate.ln())
610 / (2.0_f64.ln().powi(2)))
611 .ceil() as usize;
612 let hash_count =
613 ((size as f64 / expected_elements as f64) * 2.0_f64.ln()).ceil() as usize;
614
615 use scirs2_core::random::thread_rng;
616 let mut rng = thread_rng();
617 let hash_functions: Vec<u64> = (0..hash_count).map(|_| rng.random::<u64>()).collect();
618
619 Self {
620 bit_array: vec![false; size],
621 hash_functions,
622 size,
623 hash_count,
624 }
625 }
626
627 pub fn add(&mut self, item: u64) {
628 for &hash_seed in &self.hash_functions {
629 let hash = self.hash_item(item, hash_seed);
630 let index = (hash as usize) % self.size;
631 self.bit_array[index] = true;
632 }
633 }
634
635 pub fn contains(&self, item: u64) -> bool {
636 self.hash_functions.iter().all(|&hash_seed| {
637 let hash = self.hash_item(item, hash_seed);
638 let index = (hash as usize) % self.size;
639 self.bit_array[index]
640 })
641 }
642
643 fn hash_item(&self, item: u64, seed: u64) -> u64 {
644 item.wrapping_mul(seed).wrapping_add(seed >> 32)
645 }
646 }
647}
648
649pub mod sketching {
651 use super::*;
652 #[cfg(feature = "no-std")]
653 use alloc::{vec, vec::Vec};
654
655 pub struct RandomProjection {
657 projection_matrix: Vec<f32>,
658 original_dim: usize,
659 projected_dim: usize,
660 }
661
662 impl RandomProjection {
663 pub fn new(original_dim: usize, projected_dim: usize, epsilon: f64) -> Self {
664 let min_dim =
666 (4.0 * (2.0 * epsilon.powi(2) - epsilon.powi(3) / 3.0).ln()).ceil() as usize;
667 assert!(
668 projected_dim >= min_dim,
669 "Projected dimension too small for given epsilon"
670 );
671
672 use scirs2_core::random::thread_rng;
673 let mut rng = thread_rng();
674 let scale = (projected_dim as f32).sqrt();
675
676 let projection_matrix: Vec<f32> = (0..original_dim * projected_dim)
677 .map(|_| {
678 let u1: f32 = rng.random::<f32>();
680 let u2: f32 = rng.random::<f32>();
681 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * core::f32::consts::PI * u2).cos();
682 z / scale
683 })
684 .collect();
685
686 Self {
687 projection_matrix,
688 original_dim,
689 projected_dim,
690 }
691 }
692
693 pub fn project(&self, vector: &[f32]) -> Vec<f32> {
694 assert_eq!(vector.len(), self.original_dim);
695
696 let mut result = vec![0.0f32; self.projected_dim];
697
698 for (j, result_j) in result.iter_mut().enumerate() {
699 for (i, &v) in vector.iter().enumerate() {
700 *result_j += v * self.projection_matrix[j * self.original_dim + i];
701 }
702 }
703
704 result
705 }
706
707 pub fn batch_project(&self, vectors: &[Vec<f32>]) -> Vec<Vec<f32>> {
708 vectors.iter().map(|v| self.project(v)).collect()
709 }
710 }
711
712 pub struct FrequentItemsSketch {
714 count_min: probabilistic::CountMinSketch,
715 heavy_hitters: HashMap<u64, u32>,
716 threshold: u32,
717 total_count: u64,
718 }
719
720 impl FrequentItemsSketch {
721 pub fn new(width: usize, depth: usize, threshold: u32) -> Self {
722 Self {
723 count_min: probabilistic::CountMinSketch::new(width, depth),
724 heavy_hitters: HashMap::new(),
725 threshold,
726 total_count: 0,
727 }
728 }
729
730 pub fn update(&mut self, item: u64, count: u32) {
731 self.count_min.update(item, count);
732 self.total_count += count as u64;
733
734 let estimated_count = self.count_min.estimate(item);
735 if estimated_count >= self.threshold {
736 *self.heavy_hitters.entry(item).or_insert(0) += count;
737 }
738 }
739
740 pub fn get_frequent_items(&self) -> Vec<(u64, u32)> {
741 self.heavy_hitters.iter().map(|(&k, &v)| (k, v)).collect()
742 }
743
744 pub fn estimate_frequency(&self, item: u64) -> f64 {
745 let count = if let Some(&exact_count) = self.heavy_hitters.get(&item) {
746 exact_count
747 } else {
748 self.count_min.estimate(item)
749 };
750
751 count as f64 / self.total_count as f64
752 }
753 }
754
/// Approximate quantile sketch: a sorted list of `(value, count)` buckets.
///
/// When the list exceeds the bucket limit, it merges the neighbouring pair whose
/// gap x combined count is smallest.
pub struct QuantileSketch {
    buckets: Vec<(f64, u64)>,
    max_buckets: usize,
    total_count: u64,
}

impl QuantileSketch {
    /// Creates an empty sketch keeping at most `max_buckets` buckets.
    ///
    /// A limit of 0 is treated as 1: compression needs at least one bucket to merge into.
    pub fn new(max_buckets: usize) -> Self {
        Self {
            buckets: Vec::new(),
            max_buckets: max_buckets.max(1),
            total_count: 0,
        }
    }

    /// Records `value`.
    ///
    /// Non-finite values are ignored and not counted. NaN has no place in the ordering, and
    /// infinities of opposite sign would merge into a NaN weighted mean. Because of this
    /// filter, the comparison below can never see NaN.
    pub fn add(&mut self, value: f64) {
        if !value.is_finite() {
            return;
        }
        self.total_count += 1;

        let pos = self
            .buckets
            .binary_search_by(|(v, _)| v.partial_cmp(&value).expect("operation should succeed"))
            .unwrap_or_else(|e| e);

        if pos < self.buckets.len() && (self.buckets[pos].0 - value).abs() < 1e-10 {
            self.buckets[pos].1 += 1;
        } else {
            self.buckets.insert(pos, (value, 1));
        }

        if self.buckets.len() > self.max_buckets {
            self.compress();
        }
    }

    /// Value at quantile `q`: the first bucket whose cumulative count reaches
    /// `floor(q * total_count)`.
    ///
    /// Returns `None` if the sketch is empty or `q` is outside `[0, 1]`.
    pub fn quantile(&self, q: f64) -> Option<f64> {
        if self.buckets.is_empty() || !(0.0..=1.0).contains(&q) {
            return None;
        }

        let target_rank = (q * self.total_count as f64) as u64;
        let mut current_rank = 0;

        for &(value, count) in &self.buckets {
            current_rank += count;
            if current_rank >= target_rank {
                return Some(value);
            }
        }

        self.buckets.last().map(|(v, _)| *v)
    }

    /// Merges adjacent buckets until at most `max_buckets` remain.
    ///
    /// Each step merges the pair with the smallest `gap * combined_count` into a single
    /// bucket at their count-weighted mean. `max_buckets >= 1` (enforced in `new`), so at
    /// least two buckets exist whenever a merge is attempted.
    fn compress(&mut self) {
        while self.buckets.len() > self.max_buckets {
            let mut min_error = f64::INFINITY;
            let mut merge_idx = 0;

            for i in 0..self.buckets.len() - 1 {
                let error = (self.buckets[i + 1].0 - self.buckets[i].0)
                    * (self.buckets[i].1 + self.buckets[i + 1].1) as f64;
                if error < min_error {
                    min_error = error;
                    merge_idx = i;
                }
            }

            let merged_count = self.buckets[merge_idx].1 + self.buckets[merge_idx + 1].1;
            let merged_value = (self.buckets[merge_idx].0 * self.buckets[merge_idx].1 as f64
                + self.buckets[merge_idx + 1].0 * self.buckets[merge_idx + 1].1 as f64)
                / merged_count as f64;

            self.buckets[merge_idx] = (merged_value, merged_count);
            self.buckets.remove(merge_idx + 1);
        }
    }
}
838}
839
// Unit tests; compiled only for std builds (they use `approx` and std collections).
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_approximate_dot_product() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let expected = 70.0;
        let (result, _error) =
            approximate_ops::approximate_dot_product_f32(&a, &b, ErrorBound::MODERATE);
        assert_abs_diff_eq!(result, expected, epsilon = 1.0);
    }

    #[test]
    fn test_approximate_sum() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let expected = 15.0;

        let (result, _error) = approximate_ops::approximate_sum_f32(&data, ErrorBound::MODERATE);
        assert_abs_diff_eq!(result, expected, epsilon = 0.1);
    }

    #[test]
    fn test_f16_conversion() {
        // All of these values are exactly representable in binary16.
        let values = vec![0.0, 1.0, -1.0, 10.5, -10.5];

        for &val in &values {
            let f16_val = reduced_precision::F16::from_f32(val);
            let converted_back = f16_val.to_f32();
            assert_abs_diff_eq!(converted_back, val, epsilon = 0.1);
        }
    }

    #[test]
    fn test_u8_quantization() {
        // Step size is 20 / 255 (about 0.078), so a round trip stays well within 0.2.
        let quantizer = reduced_precision::U8Quantized::new(-10.0, 10.0);

        let values = vec![-10.0, 0.0, 10.0, 5.0, -5.0];
        for &val in &values {
            let quantized = quantizer.quantize(val);
            let dequantized = quantizer.dequantize(quantized);
            assert_abs_diff_eq!(dequantized, val, epsilon = 0.2);
        }
    }

    #[test]
    fn test_count_min_sketch() {
        let mut sketch = probabilistic::CountMinSketch::new(100, 5);

        sketch.update(42, 10);
        sketch.update(42, 5);
        sketch.update(100, 3);

        assert!(sketch.estimate(42) >= 15);
        assert!(sketch.estimate(100) >= 3);
        // 999 was never added; a non-zero estimate would need a collision in all 5 rows.
        assert_eq!(sketch.estimate(999), 0);
    }

    #[test]
    fn test_hyperloglog() {
        let mut hll = probabilistic::HyperLogLog::new(10);

        for i in 0..1000 {
            hll.add(i);
        }

        let estimate = hll.estimate();
        // Deliberately loose: the estimator is probabilistic.
        assert!((100.0..=10000.0).contains(&estimate));
    }

    #[test]
    fn test_bloom_filter() {
        let mut bloom = probabilistic::BloomFilter::new(1000, 0.01);

        for i in 0..100 {
            bloom.add(i);
        }

        // No false negatives for inserted items.
        for i in 0..100 {
            assert!(bloom.contains(i));
        }

        // The filter is sized for 1000 items at 1% false positives and holds only 100,
        // so fewer than 5 hits among 100 absent items is a generous bound.
        let mut false_positives = 0;
        for i in 100..200 {
            if bloom.contains(i) {
                false_positives += 1;
            }
        }

        assert!(false_positives < 5);
    }

    #[test]
    fn test_random_projection() {
        let projection = sketching::RandomProjection::new(100, 20, 0.1);

        let vector = (0..100).map(|i| i as f32).collect::<Vec<f32>>();
        let projected = projection.project(&vector);

        assert_eq!(projected.len(), 20);

        // The projection is linear, so doubling the input doubles the output and the two
        // projections are positively correlated.
        let vector2 = (0..100).map(|i| (i * 2) as f32).collect::<Vec<f32>>();
        let projected2 = projection.project(&vector2);

        let correlation = projected
            .iter()
            .zip(projected2.iter())
            .map(|(a, b)| a * b)
            .sum::<f32>();

        assert!(correlation > 0.0);
    }

    #[test]
    fn test_quantile_sketch() {
        let mut sketch = sketching::QuantileSketch::new(20);

        for i in 1..=100 {
            sketch.add(i as f64);
        }

        let median = sketch.quantile(0.5).expect("operation should succeed");
        assert!((45.0..=55.0).contains(&median));

        let q90 = sketch.quantile(0.9).expect("operation should succeed");
        assert!((85.0..=95.0).contains(&q90));
    }

    #[test]
    fn test_frequent_items_sketch() {
        // Threshold 5: items 42 and 100 clearly qualify; 200 reaches it only at its last update.
        let mut sketch = sketching::FrequentItemsSketch::new(100, 5, 5);
        for _ in 0..20 {
            sketch.update(42, 1);
        }
        for _ in 0..15 {
            sketch.update(100, 1);
        }
        for _ in 0..5 {
            sketch.update(200, 1);
        }

        let frequent = sketch.get_frequent_items();
        assert!(!frequent.is_empty());

        // 42 accounts for 20 of 40 occurrences.
        let freq_42 = sketch.estimate_frequency(42);
        assert!(freq_42 >= 0.3);
    }

    #[test]
    fn test_mixed_precision_matrix_multiply() {
        // [[1, 2], [3, 4]] x [[5, 6], [7, 8]], all entries exact in binary16.
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let result = reduced_precision::mixed_precision_matrix_multiply(&a, &b, 2, 2, 2);

        let expected = [19.0, 22.0, 43.0, 50.0];

        for (actual, expected) in result.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*actual, *expected, epsilon = 1.0);
        }
    }
}