// velesdb-core 1.14.1
//
// High-performance vector database engine written in Rust
// Documentation
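//! SIMD-dispatched trigram operations.
//!
//! Multi-architecture support (chosen by `TrigramSimdLevel::detect`):
//! - **x86_64 AVX-512**: 64-byte chunks, prefetch hint 128 bytes ahead
//! - **x86_64 AVX2**: 32-byte chunks, prefetch hint 64 bytes ahead
//! - **ARM NEON**: selected unconditionally on aarch64
//! - **Scalar**: Fallback for all platforms
//!
//! Trigram extraction itself is scalar (byte indexing) on every path, and
//! every path returns the same trigram set as the scalar fallback.
//!
//! # Performance Targets
//!
//! Aspirational figures for a future vectorized implementation; the current
//! scalar-extraction paths are not expected to reach them.
//!
//! | Architecture | Trigrams/cycle | Speedup vs Scalar |
//! |--------------|----------------|-------------------|
//! | AVX-512      | 21             | ~7x               |
//! | AVX2         | 10             | ~3.5x             |
//! | NEON         | 5              | ~1.8x             |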
#![allow(clippy::wildcard_imports)] // SIMD intrinsics imports are clearer in this low-level module.
#![allow(clippy::ptr_as_ptr)] // Pointer casts are intrinsic-compatible and kept explicit.
#![allow(clippy::implicit_hasher)] // Default HashSet hasher is sufficient for trigram sets.

use std::collections::HashSet;

/// A byte trigram: three consecutive bytes of the space-padded input text.
///
/// Extraction in this module operates on `str::as_bytes()`, so a multi-byte
/// UTF-8 character can be split across several trigrams.
pub type Trigram = [u8; 3];

/// SIMD capability level used to select a trigram-extraction implementation.
///
/// Variants are `cfg`-gated so only the levels that can be selected on the
/// compilation target exist; `Scalar` is available everywhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum TrigramSimdLevel {
    /// AVX-512 (512-bit vectors); selected only when both `avx512f` and
    /// `avx512bw` are detected.
    #[cfg(target_arch = "x86_64")]
    Avx512,
    /// AVX2 (256-bit vectors)
    #[cfg(target_arch = "x86_64")]
    Avx2,
    /// ARM NEON (128-bit vectors)
    #[cfg(target_arch = "aarch64")]
    Neon,
    /// Scalar fallback
    Scalar,
}

impl TrigramSimdLevel {
    /// Detect the best SIMD level available on the current CPU.
    ///
    /// - x86_64: runtime detection, preferring AVX-512 (`avx512f` + `avx512bw`),
    ///   then AVX2, then `Scalar`.
    /// - aarch64: always `Neon` (NEON is always available on aarch64).
    /// - any other target: `Scalar`.
    #[must_use]
    pub fn detect() -> Self {
        // Each target compiles exactly one cfg-gated tail expression. An early
        // `return` followed by a shared `Self::Scalar` tail would make that
        // tail unreachable on aarch64 and trigger the `unreachable_code` lint.
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw") {
                Self::Avx512
            } else if is_x86_feature_detected!("avx2") {
                Self::Avx2
            } else {
                Self::Scalar
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            // NEON is always available on aarch64
            Self::Neon
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Self::Scalar
        }
    }

    /// Human-readable name of the SIMD level (e.g. for logs or diagnostics).
    #[must_use]
    #[allow(dead_code)] // Reserved for diagnostic/logging use
    pub const fn name(self) -> &'static str {
        match self {
            #[cfg(target_arch = "x86_64")]
            Self::Avx512 => "AVX-512",
            #[cfg(target_arch = "x86_64")]
            Self::Avx2 => "AVX2",
            #[cfg(target_arch = "aarch64")]
            Self::Neon => "NEON",
            Self::Scalar => "Scalar",
        }
    }
}

/// Extract trigrams using best available SIMD instructions.
///
/// Automatically dispatches to optimal implementation based on CPU.
#[must_use]
pub fn extract_trigrams_simd(text: &str) -> HashSet<Trigram> {
    let level = TrigramSimdLevel::detect();

    match level {
        #[cfg(target_arch = "x86_64")]
        TrigramSimdLevel::Avx512 => extract_trigrams_avx512(text),
        #[cfg(target_arch = "x86_64")]
        TrigramSimdLevel::Avx2 => extract_trigrams_avx2(text),
        #[cfg(target_arch = "aarch64")]
        TrigramSimdLevel::Neon => extract_trigrams_neon(text),
        TrigramSimdLevel::Scalar => extract_trigrams_scalar(text),
    }
}

/// Scalar fallback implementation (zero-copy, no format! allocation).
///
/// Yields every overlapping 3-byte window of the virtual string
/// `"  " + text + "  "` without materializing it: positions that fall in the
/// padding read as `b' '`. Empty input yields an empty set; otherwise
/// `text.len() + 2` windows are inserted.
#[must_use]
pub fn extract_trigrams_scalar(text: &str) -> HashSet<Trigram> {
    let src = text.as_bytes();
    if src.is_empty() {
        return HashSet::new();
    }

    // Byte at `pos` of the virtual padded string; anything outside `src` is a space.
    let byte_at = |pos: usize| -> u8 {
        pos.checked_sub(2)
            .and_then(|idx| src.get(idx))
            .copied()
            .unwrap_or(b' ')
    };

    // The padded string is `src.len() + 4` bytes long, hence `src.len() + 2` windows.
    let window_count = src.len() + 2;
    let mut trigrams = HashSet::with_capacity(window_count);
    trigrams.extend(
        (0..window_count).map(|start| [byte_at(start), byte_at(start + 1), byte_at(start + 2)]),
    );
    trigrams
}

/// Build the padded byte buffer `"  " + text + "  "` used by the chunked
/// extraction paths.
///
/// Always allocates a fresh `Vec` of exactly `text.len() + 4` bytes. The
/// capacity is reserved up front, so filling it never reallocates. The two
/// leading and two trailing spaces make start- and end-of-text trigrams such as
/// `"  a"` and `"a  "` appear in the output, matching the virtual padding used
/// by `extract_trigrams_scalar`.
#[inline]
fn build_padded_bytes(text: &str) -> Vec<u8> {
    let text_bytes = text.as_bytes();
    let mut padded = Vec::with_capacity(text_bytes.len() + 4);
    padded.extend_from_slice(b"  ");
    padded.extend_from_slice(text_bytes);
    padded.extend_from_slice(b"  ");
    padded
}

/// Prefetch-enhanced trigram extraction (x86_64, AVX2 target feature).
///
/// Inserts every overlapping 3-byte window of `bytes` into the returned set;
/// inputs shorter than 3 bytes yield an empty set.
///
/// Note: The core extraction logic is scalar (byte-by-byte). The loop walks
/// the input in 32-byte chunks and issues an `_mm_prefetch` hint 64 bytes
/// ahead of each chunk. True SIMD vectorized trigram extraction would require
/// packed byte shuffles (future optimization).
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 (e.g. by checking
/// `is_x86_feature_detected!("avx2")`). This is the precondition of every
/// `#[target_feature(enable = "avx2")]` function.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[must_use]
unsafe fn extract_trigrams_avx2_inner(bytes: &[u8]) -> HashSet<Trigram> {
    use std::arch::x86_64::*;

    let mut trigrams = HashSet::with_capacity(bytes.len());
    let len = bytes.len();

    if len < 3 {
        return trigrams;
    }

    // Process 32-byte chunks with a prefetch hint for the next cache line.
    let mut i = 0;
    while i + 34 <= len {
        // `wrapping_add`, not `add`: near the end of the buffer `i + 64` lies
        // past the allocation. `pointer::add` with an out-of-bounds result is
        // undefined behavior even if the pointer is never read;
        // `wrapping_add` carries no such requirement.
        let ahead = bytes.as_ptr().wrapping_add(i + 64);
        // SAFETY: `_mm_prefetch` is a non-faulting cache hint. It never
        // dereferences `ahead`, so any address is acceptable. Its SSE
        // requirement is implied by the AVX2 feature enabled on this function.
        unsafe { _mm_prefetch(ahead as *const i8, _MM_HINT_T0) };

        // Scalar extraction of the 30 windows starting inside this chunk.
        // They read bytes i..i + 32, which the loop guard keeps in bounds.
        for j in 0..30 {
            trigrams.insert([bytes[i + j], bytes[i + j + 1], bytes[i + j + 2]]);
        }

        // The next chunk starts at the first window not yet extracted.
        i += 30;
    }

    // Tail: the remaining windows, one at a time.
    while i + 3 <= len {
        trigrams.insert([bytes[i], bytes[i + 1], bytes[i + 2]]);
        i += 1;
    }

    trigrams
}

/// Extract trigrams on x86_64, taking the AVX2-featured path when the CPU
/// supports it.
///
/// Empty input yields an empty set. When AVX2 is not available at runtime, the
/// call is forwarded to [`extract_trigrams_scalar`].
#[cfg(target_arch = "x86_64")]
#[must_use]
pub fn extract_trigrams_avx2(text: &str) -> HashSet<Trigram> {
    if text.is_empty() {
        return HashSet::new();
    }
    if !is_x86_feature_detected!("avx2") {
        return extract_trigrams_scalar(text);
    }

    // Materialize "  " + text + "  " so the inner loop can index it directly.
    let buffer = build_padded_bytes(text);
    // SAFETY: AVX2 support was confirmed by `is_x86_feature_detected!` just
    // above, which is the precondition of the `#[target_feature(enable = "avx2")]`
    // callee; `buffer` lives until the call returns.
    unsafe { extract_trigrams_avx2_inner(&buffer) }
}

/// Prefetch-enhanced trigram extraction (x86_64, AVX-512 target feature).
///
/// Inserts every overlapping 3-byte window of `bytes` into the returned set;
/// inputs shorter than 3 bytes yield an empty set.
///
/// Note: Like `extract_trigrams_avx2_inner`, the core logic is scalar. The loop
/// walks 64-byte chunks and issues an `_mm_prefetch` hint 128 bytes ahead.
/// True SIMD trigram extraction would use `vpshufb` for parallel extraction.
///
/// # Safety
///
/// The caller must ensure the running CPU supports both `avx512f` and
/// `avx512bw` (e.g. via `is_x86_feature_detected!`), as required by the
/// `#[target_feature]` attribute on this function.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
#[must_use]
unsafe fn extract_trigrams_avx512_inner(bytes: &[u8]) -> HashSet<Trigram> {
    use std::arch::x86_64::*;

    let mut trigrams = HashSet::with_capacity(bytes.len());
    let len = bytes.len();

    if len < 3 {
        return trigrams;
    }

    // Process 64-byte chunks with a prefetch hint 128 bytes ahead.
    let mut i = 0;
    while i + 66 <= len {
        // `wrapping_add`, not `add`: `i + 128` routinely lies past the end of
        // the allocation. Computing such a pointer with `pointer::add` is
        // undefined behavior even if it is never read.
        let ahead = bytes.as_ptr().wrapping_add(i + 128);
        // SAFETY: `_mm_prefetch` is a non-faulting cache hint. It never
        // dereferences `ahead`, so any address is acceptable. Its SSE
        // requirement is implied by the AVX-512 features enabled on this
        // function.
        unsafe { _mm_prefetch(ahead as *const i8, _MM_HINT_T0) };

        // Scalar extraction of the 62 windows starting inside this chunk.
        // They read bytes i..i + 64, which the loop guard keeps in bounds.
        for j in 0..62 {
            trigrams.insert([bytes[i + j], bytes[i + j + 1], bytes[i + j + 2]]);
        }

        // The next chunk starts at the first window not yet extracted.
        i += 62;
    }

    // Tail: the remaining windows, one at a time.
    while i + 3 <= len {
        trigrams.insert([bytes[i], bytes[i + 1], bytes[i + 2]]);
        i += 1;
    }

    trigrams
}

/// Extract trigrams on x86_64, taking the AVX-512-featured path when the CPU
/// reports both `avx512f` and `avx512bw`.
///
/// Empty input yields an empty set. Without AVX-512 the call is forwarded to
/// [`extract_trigrams_avx2`], which may itself fall back to the scalar path.
#[cfg(target_arch = "x86_64")]
#[must_use]
pub fn extract_trigrams_avx512(text: &str) -> HashSet<Trigram> {
    if text.is_empty() {
        return HashSet::new();
    }

    let has_avx512 = is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw");
    if !has_avx512 {
        return extract_trigrams_avx2(text);
    }

    // Materialize "  " + text + "  " so the inner loop can index it directly.
    let buffer = build_padded_bytes(text);
    // SAFETY: both `avx512f` and `avx512bw` were detected just above, which is
    // exactly the feature set the callee is compiled for; `buffer` lives until
    // the call returns.
    unsafe { extract_trigrams_avx512_inner(&buffer) }
}

/// Trigram extraction for aarch64 (the `Neon` dispatch level).
///
/// Returns every overlapping 3-byte window of `"  " + text + "  "`, the same
/// set as [`extract_trigrams_scalar`]; empty input yields an empty set.
///
/// Note: no NEON instructions are issued. A `vld1q_u8` load whose result is
/// discarded has no observable effect and may be removed entirely by the
/// optimizer, so it cannot serve as a reliable cache warmup. The window scan is
/// therefore done directly, with no `unsafe`. A truly vectorized version could
/// use `vtbl` byte-table lookups.
#[cfg(target_arch = "aarch64")]
#[must_use]
pub fn extract_trigrams_neon(text: &str) -> HashSet<Trigram> {
    if text.is_empty() {
        return HashSet::new();
    }

    // Build padded buffer (no format! allocation); it is at least 5 bytes long here.
    let padded = build_padded_bytes(text);
    let mut trigrams = HashSet::with_capacity(padded.len());
    // `windows(3)` visits each start position 0..=len-3 exactly once.
    trigrams.extend(padded.windows(3).map(|w| [w[0], w[1], w[2]]));
    trigrams
}

/// Count how many of `query_trigrams` are present in `doc_trigrams`.
///
/// Intended as the intersection term for Jaccard scoring. Each element of
/// `query_trigrams` is looked up independently, so duplicates in the query are
/// counted once per occurrence. The result equals the set-intersection size
/// only when `query_trigrams` holds no duplicates.
///
/// Note: despite the `_simd` suffix, no SIMD instructions are used. Queries of
/// 16 or more trigrams on x86_64 with AVX2 are routed through
/// `count_matching_avx2`, which performs the same per-element `HashSet`
/// lookups, so every path returns the same count.
#[must_use]
pub fn count_matching_trigrams_simd(
    query_trigrams: &[[u8; 3]],
    doc_trigrams: &HashSet<[u8; 3]>,
) -> usize {
    // For small sets, skip the runtime CPU-feature check entirely.
    if query_trigrams.len() < 16 {
        return count_matching_scalar(query_trigrams, doc_trigrams);
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            return count_matching_avx2(query_trigrams, doc_trigrams);
        }
    }

    count_matching_scalar(query_trigrams, doc_trigrams)
}

/// Scalar intersection count: one `HashSet::contains` lookup per query trigram.
fn count_matching_scalar(query_trigrams: &[[u8; 3]], doc_trigrams: &HashSet<[u8; 3]>) -> usize {
    query_trigrams
        .iter()
        .filter(|trigram| doc_trigrams.contains(*trigram))
        .count()
}

/// x86_64 AVX2 dispatch target for [`count_matching_trigrams_simd`].
///
/// Note: no AVX2 instructions are used; this delegates to the scalar
/// per-element `HashSet` lookup. It is kept as a separate function so a
/// vectorized comparison can be dropped in later without touching the
/// dispatcher.
#[cfg(target_arch = "x86_64")]
fn count_matching_avx2(query_trigrams: &[[u8; 3]], doc_trigrams: &HashSet<[u8; 3]>) -> usize {
    count_matching_scalar(query_trigrams, doc_trigrams)
}