//! AVX2 and AVX-512 accelerated pattern verification

use crate::memory_aobscan::pattern::Pattern;
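
// Note on assumptions: this module only relies on `Pattern` exposing the
// fields used below (the authoritative definition lives in
// crate::memory_aobscan::pattern). A rough sketch of the assumed shape:
//
//     pub struct Pattern {
//         pub bytes: Vec<u8>,      // pattern bytes (value at wildcard slots is ignored)
//         pub mask: Vec<bool>,     // true = byte must match, false = wildcard
//         pub mask_bytes: Vec<u8>, // 0xFF = must match, 0x00 = wildcard (SIMD-friendly form)
//     }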

/// AVX2 accelerated pattern verification (processes 32 bytes at a time)
///
/// # Safety
/// Requires AVX2 CPU support (checked before calling)
#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn verify_pattern_avx2(buffer: &[u8], offset: usize, pattern: &Pattern) -> bool {
    use std::arch::x86_64::*;
    
    let len = pattern.bytes.len();
    
    // Bounds check
    if offset + len > buffer.len() {
        return false;
    }
    
    let buf_ptr = buffer.as_ptr().add(offset);
    let pat_ptr = pattern.bytes.as_ptr();
    let mask_ptr = pattern.mask_bytes.as_ptr();
    
    // Process 32 bytes at a time using AVX2
    let mut i = 0;
    while i + 32 <= len {
        // Load data (unaligned load - modern CPUs handle this efficiently)
        let buf_chunk = _mm256_loadu_si256(buf_ptr.add(i) as *const __m256i);
        let pat_chunk = _mm256_loadu_si256(pat_ptr.add(i) as *const __m256i);
        let mask_chunk = _mm256_loadu_si256(mask_ptr.add(i) as *const __m256i);
        
        // Compare: (buf == pat) OR (NOT mask)
        // If byte is wildcard (mask=0x00), NOT mask=0xFF, result is all 1s (match)
        // If byte must match (mask=0xFF), NOT mask=0x00, result depends on comparison
        let cmp = _mm256_cmpeq_epi8(buf_chunk, pat_chunk);
        let not_mask = _mm256_andnot_si256(mask_chunk, _mm256_set1_epi8(-1));
        let result = _mm256_or_si256(cmp, not_mask);
        
        // Check if all bytes match (result should be all 0xFF = -1)
        if _mm256_movemask_epi8(result) != -1 {
            return false;
        }
        
        i += 32;
    }
    
    // Handle remaining bytes with scalar code
    for j in i..len {
        if pattern.mask[j] && buffer[offset + j] != pattern.bytes[j] {
            return false;
        }
    }
    
    true
}
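
// For reference, the per-byte predicate both SIMD paths implement is
// `!mask[j] || buffer[offset + j] == bytes[j]`. A plain scalar version of the
// whole check is sketched below purely for illustration (e.g. as a baseline to
// test the SIMD paths against); it is not part of this module's public API.
#[allow(dead_code)]
fn verify_pattern_scalar_reference(buffer: &[u8], offset: usize, pattern: &Pattern) -> bool {
    let len = pattern.bytes.len();
    if offset + len > buffer.len() {
        return false;
    }
    (0..len).all(|j| !pattern.mask[j] || buffer[offset + j] == pattern.bytes[j])
}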

/// AVX-512 accelerated pattern verification (processes 64 bytes at a time)
///
/// This provides ~30-50% speedup over AVX2 for long patterns.
///
/// # Safety
/// Requires AVX-512F and AVX-512BW CPU support (checked before calling).
/// AVX-512BW is needed for the byte-granularity intrinsics
/// (`_mm512_cmpeq_epi8_mask`, `_mm512_movepi8_mask`) used below.
#[target_feature(enable = "avx512f,avx512bw")]
#[inline]
pub unsafe fn verify_pattern_avx512(buffer: &[u8], offset: usize, pattern: &Pattern) -> bool {
    use std::arch::x86_64::*;
    
    let len = pattern.bytes.len();
    
    // Bounds check
    if offset + len > buffer.len() {
        return false;
    }
    
    let buf_ptr = buffer.as_ptr().add(offset);
    let pat_ptr = pattern.bytes.as_ptr();
    let mask_ptr = pattern.mask_bytes.as_ptr();
    
    // Process 64 bytes at a time using AVX-512
    let mut i = 0;
    while i + 64 <= len {
        // Load 64 bytes at once
        let buf_chunk = _mm512_loadu_si512(buf_ptr.add(i) as *const __m512i);
        let pat_chunk = _mm512_loadu_si512(pat_ptr.add(i) as *const __m512i);
        let mask_chunk = _mm512_loadu_si512(mask_ptr.add(i) as *const __m512i);
        
        // Compare bytes: returns mask where equal bytes have bit set
        let cmp_mask = _mm512_cmpeq_epi8_mask(buf_chunk, pat_chunk);
        
        // Get mask of required bytes (where mask_bytes is 0xFF)
        let required_mask = _mm512_movepi8_mask(mask_chunk);
        
        // Check: all required bytes must match
        // (cmp_mask & required_mask) == required_mask
        let matched_required = cmp_mask & required_mask;
        
        if matched_required != required_mask {
            return false;
        }
        
        i += 64;
    }
    
    // Handle remaining bytes with AVX2 or scalar
    if i + 32 <= len {
        // Use AVX2 for next 32 bytes
        let buf_chunk = _mm256_loadu_si256(buf_ptr.add(i) as *const __m256i);
        let pat_chunk = _mm256_loadu_si256(pat_ptr.add(i) as *const __m256i);
        let mask_chunk = _mm256_loadu_si256(mask_ptr.add(i) as *const __m256i);
        
        let cmp = _mm256_cmpeq_epi8(buf_chunk, pat_chunk);
        let not_mask = _mm256_andnot_si256(mask_chunk, _mm256_set1_epi8(-1));
        let result = _mm256_or_si256(cmp, not_mask);
        
        if _mm256_movemask_epi8(result) != -1 {
            return false;
        }
        
        i += 32;
    }
    
    // Handle final remaining bytes with scalar code
    for j in i..len {
        if pattern.mask[j] && buffer[offset + j] != pattern.bytes[j] {
            return false;
        }
    }
    
    true
}
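
// A minimal runtime-dispatch sketch (illustrative only; any real entry point
// in the crate may differ): the `unsafe` verifiers above are only sound to
// call once the corresponding CPU features have been detected at runtime,
// e.g. with `is_x86_feature_detected!`.
#[allow(dead_code)]
fn verify_pattern_dispatch_example(buffer: &[u8], offset: usize, pattern: &Pattern) -> bool {
    if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw") {
        // SAFETY: the required AVX-512 features were detected at runtime above.
        unsafe { verify_pattern_avx512(buffer, offset, pattern) }
    } else if is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 was detected at runtime above.
        unsafe { verify_pattern_avx2(buffer, offset, pattern) }
    } else {
        // Scalar fallback: every non-wildcard byte must match.
        offset + pattern.bytes.len() <= buffer.len()
            && pattern
                .bytes
                .iter()
                .enumerate()
                .all(|(j, &b)| !pattern.mask[j] || buffer[offset + j] == b)
    }
}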