use crate::types::{Array, ArraySlice};
use log::debug;
use measure_time_macro::log_time;
use rand::rngs::SmallRng;
use rand::SeedableRng;
/// Draws up to `n` rows of width `shape` from the flat `queries` buffer.
///
/// `queries` is treated as `queries.len() / shape` consecutive rows. When `n` is at least
/// that row count, a copy of the whole buffer is returned. Otherwise `n` row indices are
/// drawn with a [`SmallRng`] seeded from `seed`, and the matching rows are copied into the
/// result in the order the indices were produced.
///
/// # Panics
///
/// Panics if `queries` is empty, if `queries.len()` is not a multiple of `shape`, or if
/// `n` is zero.
#[log_time]
pub(crate) fn sample(queries: &ArraySlice, n: usize, shape: usize, seed: u64) -> Array {
    debug!(n=n ;"sampling");

    // Preconditions are checked in this order so the reported panic stays stable.
    assert!(!queries.is_empty(), "Queries cannot be empty");
    assert!(queries.len().is_multiple_of(shape));
    assert!(n > 0, "Sample size must be greater than zero");

    let row_count = queries.len() / shape;
    if row_count <= n {
        // Nothing to choose from: the request covers every available row.
        return queries.to_vec();
    }

    // Seeding from the caller's value makes the draw reproducible.
    let mut rng = SmallRng::seed_from_u64(seed);
    let picked_rows = rand::seq::index::sample(&mut rng, row_count, n).into_vec();

    let mut sampled = Array::with_capacity(n * shape);
    for row in picked_rows {
        let offset = row * shape;
        sampled.extend_from_slice(&queries[offset..offset + shape]);
    }
    sampled
}
/// Chooses how many objects to sample for a given `k`.
///
/// The result is the larger of `k * 40` and `sample_threshold`, capped at `total_objects`.
/// The multiplication saturates, so an extremely large `k` resolves to `total_objects`
/// instead of overflowing (which would panic in debug builds and wrap to a bogus, small
/// size in release builds).
pub(crate) fn select_sample_size(k: usize, total_objects: usize, sample_threshold: usize) -> usize {
    // Multiplier applied to `k` to obtain the base sample size.
    const BASE_SIZE_PER_K: usize = 40;

    let base_size = k.saturating_mul(BASE_SIZE_PER_K);
    // `max` picks `sample_threshold` exactly when `base_size` is below it.
    total_objects.min(base_size.max(sample_threshold))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ArrayNumType;

    /// Builds a flat array holding the values `1.0..=count` in ascending order.
    fn ascending(count: usize) -> Array {
        (1..=count).map(|v| v as ArrayNumType).collect()
    }

    /// Asserts that `sampled` holds exactly `expected_rows` rows of width `shape`, each copied
    /// intact from `queries`, and that no input row was used twice.
    ///
    /// Rows are matched by value, so the rows of `queries` must be pairwise distinct.
    fn assert_valid_sample(
        queries: &ArraySlice,
        sampled: &ArraySlice,
        shape: usize,
        expected_rows: usize,
    ) {
        assert_eq!(sampled.len(), expected_rows * shape);
        let mut used_rows = Vec::with_capacity(expected_rows);
        for row in sampled.chunks_exact(shape) {
            let position = queries
                .chunks_exact(shape)
                .position(|candidate| candidate == row)
                .expect("sampled row must be copied intact from the input");
            assert!(
                !used_rows.contains(&position),
                "input row {position} was sampled more than once"
            );
            used_rows.push(position);
        }
    }

    #[test]
    fn test_sample_basic_functionality() {
        let queries = ascending(15);
        let sampled = sample(&queries, 3, 3, 42);
        assert_valid_sample(&queries, &sampled, 3, 3);
    }

    #[test]
    fn test_sample_single_query() {
        let queries = vec![1.0, 2.0, 3.0];
        let sampled = sample(&queries, 1, 3, 42);
        assert_eq!(sampled.len(), 3);
        assert_eq!(sampled, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_sample_all_queries() {
        let queries = ascending(15);
        let original_count = queries.len() / 3;
        let sampled = sample(&queries, original_count, 3, 42);
        // Requesting every row returns the input unchanged.
        assert_eq!(sampled, queries);
    }

    #[test]
    fn test_sample_more_than_available() {
        let queries = ascending(15);
        let sampled = sample(&queries, 10, 3, 42);
        // Requesting more rows than exist returns the input unchanged.
        assert_eq!(sampled, queries);
    }

    #[test]
    fn test_sample_preserves_array_structure() {
        let queries = ascending(6);
        let sampled = sample(&queries, 2, 2, 42);
        assert_valid_sample(&queries, &sampled, 2, 2);
        for chunk in sampled.chunks(2) {
            assert_eq!(chunk.len(), 2);
            // Every input row is a pair of consecutive values, so intact rows differ by one.
            let diff = chunk[1] - chunk[0];
            assert!((diff - 1.0).abs() < f32::EPSILON);
        }
    }

    #[test]
    fn test_sample_different_sized_shape() {
        let queries = ascending(8);
        let sampled = sample(&queries, 1, 4, 42);
        assert_valid_sample(&queries, &sampled, 4, 1);
    }

    #[test]
    fn test_sample_randomness() {
        let queries = ascending(15);
        let reference = sample(&queries, 3, 3, 0);
        let mut saw_different_outcome = false;
        for seed in 1..32 {
            let sampled = sample(&queries, 3, 3, seed);
            assert_valid_sample(&queries, &sampled, 3, 3);
            saw_different_outcome |= sampled != reference;
        }
        assert!(
            saw_different_outcome,
            "the seed must influence which rows are drawn"
        );
    }

    #[test]
    fn test_sample_maintains_value_integrity() {
        let queries = vec![1.1, 2.2, 3.3, 4.4, 5.5, 6.6];
        let sampled = sample(&queries, 2, 3, 42);
        assert_eq!(sampled.len(), 6);
        // Both rows are requested, so the values must come back untouched and in order.
        assert_eq!(sampled, queries);
    }

    #[test]
    #[should_panic(expected = "Queries cannot be empty")]
    fn test_sample_empty_queries_panics() {
        let queries: Array = vec![];
        sample(&queries, 1, 1, 42);
    }

    #[test]
    #[should_panic(expected = "Sample size must be greater than zero")]
    fn test_sample_zero_size_panics() {
        let queries = vec![1.0, 2.0, 3.0];
        sample(&queries, 0, 3, 42);
    }

    #[test]
    fn test_sample_single_element_shape() {
        let queries = ascending(4);
        let sampled = sample(&queries, 2, 1, 42);
        assert_valid_sample(&queries, &sampled, 1, 2);
    }

    #[test]
    fn test_sample_large_arrays() {
        let queries: Array = (0..300).map(|i| i as ArrayNumType).collect();
        let sampled = sample(&queries, 2, 100, 42);
        assert_valid_sample(&queries, &sampled, 100, 2);
    }

    #[test]
    fn test_sample_deterministic_with_fixed_seed() {
        let queries = ascending(15);
        let first_run = sample(&queries, 3, 3, 42);
        let second_run = sample(&queries, 3, 3, 42);
        assert_eq!(first_run, second_run);
    }

    #[test]
    fn test_select_sample_size_uses_threshold_for_small_k() {
        assert_eq!(select_sample_size(1, 1000, 100), 100);
    }

    #[test]
    fn test_select_sample_size_scales_with_k() {
        assert_eq!(select_sample_size(10, 1000, 100), 400);
    }

    #[test]
    fn test_select_sample_size_capped_by_total_objects() {
        assert_eq!(select_sample_size(10, 50, 100), 50);
        assert_eq!(select_sample_size(1, 50, 100), 50);
    }
}