use std::simd::prelude::*;
use std::simd::Simd;
/// Number of `f32` lanes in one SIMD vector (8 x 32 bits = 256 bits).
const F32_LANES: usize = 8;
/// Number of `u8` lanes in one SIMD vector (32 x 8 bits = 256 bits).
const U8_LANES: usize = 32;
/// SIMD vector holding [`F32_LANES`] `f32` values.
type F32x8 = Simd<f32, F32_LANES>;
/// SIMD vector holding [`U8_LANES`] `u8` values.
type U8x32 = Simd<u8, U8_LANES>;
/// Adds up every element of `data`, processing [`F32_LANES`] values at a time.
///
/// Full groups of [`F32_LANES`] elements are accumulated lane-wise into a
/// single vector which is reduced horizontally once at the end; the leftover
/// tail (fewer than [`F32_LANES`] elements) is then added to that total one
/// value at a time, in slice order.
///
/// Because additions are grouped per lane, the result may differ in the last
/// bits from a strictly left-to-right scalar sum.
///
/// Returns `0.0` for an empty slice.
#[inline]
pub fn sum_f32(data: &[f32]) -> f32 {
    if data.is_empty() {
        return 0.0;
    }

    let blocks = data.chunks_exact(F32_LANES);
    let tail = blocks.remainder();

    // Lane-wise running totals over all full blocks.
    let lane_totals = blocks.fold(F32x8::splat(0.0), |running, block| {
        running + F32x8::from_slice(block)
    });

    // Seed the scalar pass with the horizontal sum so the addition order is
    // "vector total first, then tail elements in order".
    tail.iter()
        .fold(lane_totals.reduce_sum(), |total, &value| total + value)
}
/// Computes the dot product of `a` and `b`, [`F32_LANES`] elements at a time.
///
/// Both slices are expected to have the same length. This is checked with a
/// `debug_assert_eq!`, so a mismatch panics in builds with debug assertions
/// enabled; otherwise only the common prefix (the shorter length) is used.
///
/// Full groups of [`F32_LANES`] element pairs are multiplied and accumulated
/// lane-wise, reduced horizontally once, and the remaining tail pairs are then
/// added one at a time in slice order.
///
/// Returns `0.0` when the common length is zero.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(
        a.len(),
        b.len(),
        "dot_product requires equal-length slices"
    );

    let shared_len = a.len().min(b.len());
    if shared_len == 0 {
        return 0.0;
    }
    let (lhs, rhs) = (&a[..shared_len], &b[..shared_len]);

    let lhs_blocks = lhs.chunks_exact(F32_LANES);
    let rhs_blocks = rhs.chunks_exact(F32_LANES);
    let lhs_tail = lhs_blocks.remainder();
    let rhs_tail = rhs_blocks.remainder();

    // Lane-wise running totals of the element products over all full blocks.
    let lane_totals = lhs_blocks
        .zip(rhs_blocks)
        .fold(F32x8::splat(0.0), |running, (x, y)| {
            running + F32x8::from_slice(x) * F32x8::from_slice(y)
        });

    // Scalar pass over the tail, seeded with the horizontal sum.
    lhs_tail
        .iter()
        .zip(rhs_tail)
        .fold(lane_totals.reduce_sum(), |total, (&x, &y)| total + x * y)
}
/// Counts how many elements of `data` are equal to `byte`.
///
/// Full groups of [`U8_LANES`] bytes are compared against a broadcast copy of
/// `byte`; the number of matching lanes is obtained from the popcount of the
/// comparison bitmask. The leftover tail is counted with a scalar comparison.
///
/// Returns `0` for an empty slice.
#[inline]
pub fn count_byte(data: &[u8], byte: u8) -> usize {
    if data.is_empty() {
        return 0;
    }

    let needle = U8x32::splat(byte);
    let blocks = data.chunks_exact(U8_LANES);
    let tail = blocks.remainder();

    let vector_hits: usize = blocks
        .map(|block| {
            let hits = U8x32::from_slice(block).simd_eq(needle);
            // One bit per matching lane.
            hits.to_bitmask().count_ones() as usize
        })
        .sum();
    let tail_hits = tail.iter().filter(|&&value| value == byte).count();

    vector_hits + tail_hits
}
/// Reports whether every byte in `data` has the same value.
///
/// Slices with zero or one element are trivially uniform and yield `true`.
/// Otherwise every full group of [`U8_LANES`] bytes is compared against a
/// broadcast copy of the first byte, followed by a scalar check of the tail.
/// Evaluation stops at the first group or tail byte that differs.
#[inline]
pub fn all_equal(data: &[u8]) -> bool {
    let first = match data {
        [] | [_] => return true,
        [head, ..] => *head,
    };

    let expected = U8x32::splat(first);
    let mut blocks = data.chunks_exact(U8_LANES);
    let tail = blocks.remainder();

    blocks.all(|block| U8x32::from_slice(block).simd_eq(expected).all())
        && tail.iter().all(|&value| value == first)
}
/// Returns the byte offsets of every `b'\n'` in `data`, in ascending order.
///
/// Full groups of [`U8_LANES`] bytes are compared against a broadcast newline
/// and the resulting bitmask is walked one set bit at a time; the leftover
/// tail is scanned byte by byte.
///
/// Returns an empty vector for an empty slice.
#[inline]
pub fn find_newlines(data: &[u8]) -> Vec<usize> {
    if data.is_empty() {
        return Vec::new();
    }

    // Pre-allocation heuristic: one slot per 40 input bytes, at least one.
    let mut positions = Vec::with_capacity((data.len() / 40).max(1));

    let needle = U8x32::splat(b'\n');
    let blocks = data.chunks_exact(U8_LANES);
    let tail = blocks.remainder();
    let tail_offset = data.len() - tail.len();

    for (block_no, block) in blocks.enumerate() {
        let block_offset = block_no * U8_LANES;
        let mut pending = U8x32::from_slice(block).simd_eq(needle).to_bitmask();
        // Each set bit marks a matching lane. Taking the lowest set bit first
        // emits offsets in ascending order; the loop is skipped entirely when
        // the block contains no newline.
        while pending != 0 {
            positions.push(block_offset + pending.trailing_zeros() as usize);
            pending &= pending - 1; // clear the lowest set bit
        }
    }

    for (index, _) in tail
        .iter()
        .enumerate()
        .filter(|&(_, &value)| value == b'\n')
    {
        positions.push(tail_offset + index);
    }

    positions
}
/// Returns the squared Euclidean norm of `data`, i.e. the sum of the squares
/// of its elements, computed as `dot_product(data, data)`.
///
/// Yields `0.0` for an empty slice (the value `dot_product` returns for
/// zero-length input).
#[inline]
pub fn squared_norm(data: &[f32]) -> f32 {
dot_product(data, data)
}
/// Computes the cosine similarity `dot(a, b) / (|a| * |b|)` of two vectors.
///
/// Both slices are expected to have the same length; the length handling is
/// that of [`dot_product`], which this function calls.
///
/// The product of the two squared norms is formed in `f64` before taking the
/// square root. That product can overflow or underflow `f32` even when the
/// final ratio is perfectly representable, and an absolute small-value cutoff
/// would make the result depend on the scale of the inputs (cosine similarity
/// is scale-invariant).
///
/// # Returns
///
/// * `0.0` if either vector has a squared norm of exactly zero (including
///   empty input), since the similarity is undefined there.
/// * Otherwise the similarity, clamped to `[-1.0, 1.0]` so rounding error can
///   never push it outside the mathematically valid range.
/// * `NaN` inputs propagate to a `NaN` result.
#[inline]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot = dot_product(a, b);
    let norm_a = squared_norm(a);
    let norm_b = squared_norm(b);

    // Widen before multiplying: two non-zero `f32` values always have a
    // non-zero, finite product in `f64`, so the denominator is zero only when
    // one of the norms is exactly zero.
    let denominator = (f64::from(norm_a) * f64::from(norm_b)).sqrt();
    if denominator == 0.0 {
        return 0.0;
    }

    // Narrowing back to `f32` is intentional: the public result type is `f32`.
    let cosine = (f64::from(dot) / denominator) as f32;
    cosine.clamp(-1.0, 1.0)
}
/// Number of `u32` lanes in one SIMD vector (8 x 32 bits = 256 bits).
const U32_LANES: usize = 8;
/// SIMD vector holding [`U32_LANES`] `u32` values.
type U32x8 = Simd<u32, U32_LANES>;
#[inline]
pub fn find_matching_u32(data: &[u32], target: u32) -> Vec<usize> {
if data.is_empty() {
return Vec::new();
}
let estimated_matches = (data.len() / 10).max(4);
let mut indices = Vec::with_capacity(estimated_matches);
let target_simd = U32x8::splat(target);
let chunks = data.chunks_exact(U32_LANES);
let remainder = chunks.remainder();
let remainder_start = data.len() - remainder.len();
for (chunk_idx, chunk) in chunks.enumerate() {
let simd_chunk = U32x8::from_slice(chunk);
let matches = simd_chunk.simd_eq(target_simd);
let mask = matches.to_bitmask();
if mask != 0 {
let base = chunk_idx * U32_LANES;
let mut remaining_mask = mask;
while remaining_mask != 0 {
let lane = remaining_mask.trailing_zeros() as usize;
indices.push(base + lane);
remaining_mask &= remaining_mask - 1; }
}
}
for (i, &val) in remainder.iter().enumerate() {
if val == target {
indices.push(remainder_start + i);
}
}
indices
}
/// Writes the indices of every element of `data` equal to `target` into
/// `output`, in ascending order.
///
/// `output` is cleared first, so any previous contents are discarded while its
/// allocation is kept for reuse. Full groups of [`U32_LANES`] values are
/// compared against a broadcast copy of `target` and the resulting bitmask is
/// walked one set bit at a time; the leftover tail is scanned element by
/// element.
#[inline]
pub fn find_matching_u32_into(data: &[u32], target: u32, output: &mut Vec<usize>) {
    output.clear();
    if data.is_empty() {
        return;
    }

    let needle = U32x8::splat(target);
    let blocks = data.chunks_exact(U32_LANES);
    let tail = blocks.remainder();
    let tail_offset = data.len() - tail.len();

    for (block_no, block) in blocks.enumerate() {
        let block_offset = block_no * U32_LANES;
        let mut pending = U32x8::from_slice(block).simd_eq(needle).to_bitmask();
        // Each set bit marks a matching lane. Taking the lowest set bit first
        // emits indices in ascending order; the loop is skipped entirely when
        // the block has no match.
        while pending != 0 {
            output.push(block_offset + pending.trailing_zeros() as usize);
            pending &= pending - 1; // clear the lowest set bit
        }
    }

    for (index, _) in tail
        .iter()
        .enumerate()
        .filter(|&(_, &value)| value == target)
    {
        output.push(tail_offset + index);
    }
}
// Unit tests for the SIMD helpers. Each function is exercised on empty input,
// inputs shorter than one SIMD vector, inputs that are an exact multiple of
// the lane count, and inputs with a scalar tail.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sum_f32_empty() {
        assert_eq!(sum_f32(&[]), 0.0);
    }

    #[test]
    fn test_sum_f32_single() {
        assert!((sum_f32(&[42.0]) - 42.0).abs() < 1e-6);
    }

    #[test]
    fn test_sum_f32_exact_lanes() {
        let data: Vec<f32> = (1..=8).map(|x| x as f32).collect();
        assert!((sum_f32(&data) - 36.0).abs() < 1e-6);
    }

    #[test]
    fn test_sum_f32_with_remainder() {
        let data: Vec<f32> = (1..=11).map(|x| x as f32).collect();
        assert!((sum_f32(&data) - 66.0).abs() < 1e-6);
    }

    #[test]
    fn test_sum_f32_large() {
        let data: Vec<f32> = (1..=1000).map(|x| x as f32).collect();
        assert!((sum_f32(&data) - 500_500.0).abs() < 1e-2);
    }

    #[test]
    fn test_dot_product_empty() {
        assert_eq!(dot_product(&[], &[]), 0.0);
    }

    #[test]
    fn test_dot_product_single() {
        assert!((dot_product(&[3.0], &[4.0]) - 12.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_basic() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        assert!((dot_product(&a, &b) - 32.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_exact_lanes() {
        let a: Vec<f32> = vec![1.0; 8];
        let b: Vec<f32> = vec![2.0; 8];
        assert!((dot_product(&a, &b) - 16.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_with_remainder() {
        let a: Vec<f32> = vec![1.0; 11];
        let b: Vec<f32> = vec![2.0; 11];
        assert!((dot_product(&a, &b) - 22.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_large_embeddings() {
        let a: Vec<f32> = (0..1024).map(|i| (i as f32) * 0.001).collect();
        let b: Vec<f32> = (0..1024).map(|i| (1024 - i) as f32 * 0.001).collect();
        let result = dot_product(&a, &b);
        assert!(result.is_finite());
        assert!(result > 0.0);
    }

    #[test]
    fn test_count_byte_empty() {
        assert_eq!(count_byte(&[], b'x'), 0);
    }

    #[test]
    fn test_count_byte_not_found() {
        assert_eq!(count_byte(b"hello world", b'x'), 0);
    }

    #[test]
    fn test_count_byte_found() {
        assert_eq!(count_byte(b"hello", b'l'), 2);
        assert_eq!(count_byte(b"hello", b'h'), 1);
        assert_eq!(count_byte(b"hello", b'o'), 1);
    }

    #[test]
    fn test_count_byte_newlines() {
        let text = b"line1\nline2\nline3\n";
        assert_eq!(count_byte(text, b'\n'), 3);
    }

    #[test]
    fn test_count_byte_exact_lanes() {
        let data = vec![b'a'; 32];
        assert_eq!(count_byte(&data, b'a'), 32);
        assert_eq!(count_byte(&data, b'b'), 0);
    }

    #[test]
    fn test_count_byte_with_remainder() {
        let data = vec![b'a'; 37];
        assert_eq!(count_byte(&data, b'a'), 37);
    }

    #[test]
    fn test_count_byte_large() {
        let mut data = vec![b' '; 10000];
        data[0] = b'x';
        data[5000] = b'x';
        data[9999] = b'x';
        assert_eq!(count_byte(&data, b'x'), 3);
    }

    #[test]
    fn test_all_equal_empty() {
        assert!(all_equal(&[]));
    }

    #[test]
    fn test_all_equal_single() {
        assert!(all_equal(&[b'a']));
    }

    #[test]
    fn test_all_equal_true() {
        assert!(all_equal(b"aaaaaa"));
        assert!(all_equal(&[0u8; 100]));
        assert!(all_equal(&[255u8; 100]));
    }

    #[test]
    fn test_all_equal_false_early() {
        assert!(!all_equal(b"ab"));
        assert!(!all_equal(b"ba"));
    }

    #[test]
    fn test_all_equal_false_late() {
        let mut data = vec![b'a'; 100];
        data[99] = b'b';
        assert!(!all_equal(&data));
    }

    #[test]
    fn test_all_equal_exact_lanes() {
        assert!(all_equal(&[b'x'; 32]));
        let mut data = vec![b'x'; 32];
        data[31] = b'y';
        assert!(!all_equal(&data));
    }

    #[test]
    fn test_find_newlines_empty() {
        assert_eq!(find_newlines(&[]), Vec::<usize>::new());
    }

    #[test]
    fn test_find_newlines_none() {
        assert_eq!(find_newlines(b"no newlines here"), Vec::<usize>::new());
    }

    #[test]
    fn test_find_newlines_single() {
        assert_eq!(find_newlines(b"hello\nworld"), vec![5]);
    }

    #[test]
    fn test_find_newlines_multiple() {
        assert_eq!(find_newlines(b"a\nb\nc\n"), vec![1, 3, 5]);
    }

    #[test]
    fn test_find_newlines_at_boundaries() {
        assert_eq!(find_newlines(b"\nstart"), vec![0]);
        assert_eq!(find_newlines(b"end\n"), vec![3]);
        assert_eq!(find_newlines(b"\n"), vec![0]);
    }

    #[test]
    fn test_find_newlines_consecutive() {
        assert_eq!(find_newlines(b"\n\n\n"), vec![0, 1, 2]);
    }

    #[test]
    fn test_find_newlines_exact_lanes() {
        let mut data = vec![b'x'; 32];
        data[0] = b'\n';
        data[15] = b'\n';
        data[31] = b'\n';
        assert_eq!(find_newlines(&data), vec![0, 15, 31]);
    }

    #[test]
    fn test_find_newlines_with_remainder() {
        let mut data = vec![b'x'; 37];
        data[5] = b'\n';
        data[35] = b'\n';
        assert_eq!(find_newlines(&data), vec![5, 35]);
    }

    #[test]
    fn test_squared_norm() {
        let v = [3.0_f32, 4.0];
        assert!((squared_norm(&v) - 25.0).abs() < 1e-6);
    }

    #[test]
    fn test_squared_norm_empty() {
        assert_eq!(squared_norm(&[]), 0.0);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = [1.0_f32, 2.0, 3.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = [1.0_f32, 0.0];
        let b = [0.0_f32, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = [1.0_f32, 0.0];
        let b = [-1.0_f32, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = [0.0_f32, 0.0];
        let b = [1.0_f32, 1.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_find_matching_u32_empty() {
        assert_eq!(find_matching_u32(&[], 5), Vec::<usize>::new());
    }

    #[test]
    fn test_find_matching_u32_not_found() {
        let data = [1u32, 2, 3, 4, 6, 7, 8, 9];
        assert_eq!(find_matching_u32(&data, 5), Vec::<usize>::new());
    }

    #[test]
    fn test_find_matching_u32_single() {
        let data = [1u32, 5, 3, 4, 6, 7, 8, 9];
        assert_eq!(find_matching_u32(&data, 5), vec![1]);
    }

    #[test]
    fn test_find_matching_u32_multiple() {
        let data = [1u32, 5, 3, 5, 5, 2, 7, 5, 9, 5];
        assert_eq!(find_matching_u32(&data, 5), vec![1, 3, 4, 7, 9]);
    }

    #[test]
    fn test_find_matching_u32_exact_lanes() {
        let data = [5u32, 1, 5, 2, 3, 5, 4, 5];
        assert_eq!(find_matching_u32(&data, 5), vec![0, 2, 5, 7]);
    }

    #[test]
    fn test_find_matching_u32_with_remainder() {
        let data = [1u32, 2, 3, 4, 5, 6, 7, 8, 9, 5, 5];
        assert_eq!(find_matching_u32(&data, 5), vec![4, 9, 10]);
    }

    #[test]
    fn test_find_matching_u32_all_match() {
        let data = [5u32; 16];
        let expected: Vec<usize> = (0..16).collect();
        assert_eq!(find_matching_u32(&data, 5), expected);
    }

    #[test]
    fn test_find_matching_u32_large() {
        let mut data = vec![0u32; 1000];
        data[0] = 42;
        data[100] = 42;
        data[500] = 42;
        data[999] = 42;
        assert_eq!(find_matching_u32(&data, 42), vec![0, 100, 500, 999]);
    }

    #[test]
    fn test_find_matching_u32_into_reuse() {
        let data1 = [1u32, 5, 3, 5];
        let data2 = [5u32, 5, 5, 1, 1, 5];
        let mut buffer = Vec::new();
        find_matching_u32_into(&data1, 5, &mut buffer);
        assert_eq!(buffer, vec![1, 3]);
        find_matching_u32_into(&data2, 5, &mut buffer);
        assert_eq!(buffer, vec![0, 1, 2, 5]);
    }

    // The reuse test above only feeds inputs shorter than one SIMD vector, so
    // this case covers the vectorised path of `find_matching_u32_into`: two
    // full 8-lane blocks followed by a 3-element tail, with matches on both
    // sides of each block boundary and stale data in the buffer.
    #[test]
    fn test_find_matching_u32_into_blocks_and_tail() {
        let mut data = [0u32; 19];
        data[0] = 7;
        data[7] = 7;
        data[8] = 7;
        data[18] = 7;
        let mut buffer = vec![123usize];
        find_matching_u32_into(&data, 7, &mut buffer);
        assert_eq!(buffer, vec![0, 7, 8, 18]);
    }

    // Empty input must still discard whatever the buffer held before.
    #[test]
    fn test_find_matching_u32_into_empty_clears_buffer() {
        let mut buffer = vec![1usize, 2, 3];
        find_matching_u32_into(&[], 5, &mut buffer);
        assert!(buffer.is_empty());
    }
}