use crate::algorithms::suffix_array::SuffixArray;
use crate::compression::dict_zip::compression_types::{Match, CompressionType};
use crate::compression::dict_zip::matcher::{PatternMatcher, MatcherConfig};
use crate::memory::simd_ops::{SimdMemOps, get_global_simd_ops};
use crate::string::{SimdStringSearch, get_global_simd_search};
use crate::system::cpu_features::{CpuFeatures, get_cpu_features};
use crate::error::Result;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::cmp::Ordering;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdPatternTier {
Scalar,
Sse42,
Avx2,
#[cfg(feature = "avx512")]
Avx512,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ParallelMode {
X1 = 1,
X2 = 2,
X4 = 4,
X8 = 8,
}
impl ParallelMode {
pub fn stream_count(self) -> usize {
self as usize
}
pub fn is_supported(self, cpu_features: &CpuFeatures) -> bool {
match self {
ParallelMode::X1 => true, ParallelMode::X2 => cpu_features.has_sse42,
ParallelMode::X4 => cpu_features.has_avx2,
ParallelMode::X8 => cpu_features.has_avx2 && cpu_features.logical_cores >= 8,
}
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct SimdMatchResult {
pub base_match: Match,
pub length: usize,
pub input_position: usize,
pub dict_position: usize,
pub quality: f64,
pub simd_tier: SimdPatternTier,
pub parallel_mode: ParallelMode,
pub simd_accelerated: bool,
pub simd_operations: u32,
pub search_time_ns: u64,
}
impl SimdMatchResult {
pub fn new(
base_match: Match,
input_position: usize,
dict_position: usize,
simd_tier: SimdPatternTier,
parallel_mode: ParallelMode,
simd_accelerated: bool,
) -> Self {
let length = base_match.length();
let quality = Self::calculate_quality(length, simd_accelerated);
Self {
base_match,
length,
input_position,
dict_position,
quality,
simd_tier,
parallel_mode,
simd_accelerated,
simd_operations: 0,
search_time_ns: 0,
}
}
fn calculate_quality(length: usize, simd_accelerated: bool) -> f64 {
let base_quality = 1.0 - (-(length as f64) / 128.0).exp();
if simd_accelerated {
(base_quality * 1.1).min(1.0) } else {
base_quality
}
}
pub fn is_better_than(&self, other: &SimdMatchResult) -> bool {
match self.length.cmp(&other.length) {
Ordering::Greater => true,
Ordering::Equal => {
match (self.simd_accelerated, other.simd_accelerated) {
(true, false) => true,
(false, true) => false,
_ => self.quality > other.quality,
}
}
Ordering::Less => false,
}
}
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SimdPatternConfig {
pub enable_simd: bool,
pub parallel_mode: ParallelMode,
pub min_simd_length: usize,
pub max_single_sse_length: usize,
pub max_cascaded_sse_length: usize,
pub enable_early_termination: bool,
pub early_termination_quality: f64,
pub enable_cache_optimization: bool,
pub max_simd_operations: u32,
pub enable_prefetch: bool,
pub enable_bmi2: bool,
}
impl Default for SimdPatternConfig {
fn default() -> Self {
Self {
enable_simd: true,
parallel_mode: ParallelMode::X1,
min_simd_length: 4,
max_single_sse_length: 16,
max_cascaded_sse_length: 35,
enable_early_termination: true,
early_termination_quality: 0.95,
enable_cache_optimization: true,
max_simd_operations: 1000,
enable_prefetch: true,
enable_bmi2: false, }
}
}
impl SimdPatternConfig {
pub fn high_throughput() -> Self {
Self {
parallel_mode: ParallelMode::X4,
max_simd_operations: 2000,
early_termination_quality: 0.9,
..Default::default()
}
}
pub fn low_latency() -> Self {
Self {
parallel_mode: ParallelMode::X1,
max_simd_operations: 500,
early_termination_quality: 0.98,
enable_cache_optimization: true,
..Default::default()
}
}
pub fn maximum_parallelism() -> Self {
Self {
parallel_mode: ParallelMode::X8,
max_simd_operations: 5000,
early_termination_quality: 0.85,
enable_prefetch: true,
..Default::default()
}
}
}
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SimdPatternStats {
pub total_searches: u64,
pub simd_accelerated_searches: u64,
pub successful_matches: u64,
pub total_simd_operations: u64,
pub total_search_time_ns: u64,
pub avg_pattern_length: f64,
pub avg_match_quality: f64,
pub simd_hit_rate: f64,
pub tier_usage: [u64; 4], pub parallel_mode_usage: [u64; 4], }
impl SimdPatternStats {
pub fn success_rate(&self) -> f64 {
if self.total_searches == 0 {
0.0
} else {
self.successful_matches as f64 / self.total_searches as f64
}
}
pub fn avg_search_time_ns(&self) -> f64 {
if self.total_searches == 0 {
0.0
} else {
self.total_search_time_ns as f64 / self.total_searches as f64
}
}
pub fn simd_acceleration_ratio(&self) -> f64 {
if self.total_searches == 0 {
0.0
} else {
self.simd_accelerated_searches as f64 / self.total_searches as f64
}
}
pub fn update_with_result(&mut self, result: &SimdMatchResult) {
self.total_searches += 1;
self.total_simd_operations += result.simd_operations as u64;
self.total_search_time_ns += result.search_time_ns;
if result.simd_accelerated {
self.simd_accelerated_searches += 1;
}
if result.length > 0 {
self.successful_matches += 1;
let total_length = self.avg_pattern_length * (self.successful_matches - 1) as f64 + result.length as f64;
self.avg_pattern_length = total_length / self.successful_matches as f64;
let total_quality = self.avg_match_quality * (self.successful_matches - 1) as f64 + result.quality;
self.avg_match_quality = total_quality / self.successful_matches as f64;
}
let tier_index = match result.simd_tier {
SimdPatternTier::Scalar => 0,
SimdPatternTier::Sse42 => 1,
SimdPatternTier::Avx2 => 2,
#[cfg(feature = "avx512")]
SimdPatternTier::Avx512 => 3,
};
self.tier_usage[tier_index] += 1;
let mode_index = match result.parallel_mode {
ParallelMode::X1 => 0,
ParallelMode::X2 => 1,
ParallelMode::X4 => 2,
ParallelMode::X8 => 3,
};
self.parallel_mode_usage[mode_index] += 1;
self.simd_hit_rate = self.simd_acceleration_ratio();
}
}
pub struct SimdPatternMatcher {
config: SimdPatternConfig,
cpu_features: &'static CpuFeatures,
simd_tier: SimdPatternTier,
simd_ops: &'static SimdMemOps,
simd_search: &'static SimdStringSearch,
suffix_array: Option<Arc<SuffixArray>>,
dictionary_text: Option<Arc<Vec<u8>>>,
stats: SimdPatternStats,
fallback_matcher: Option<Arc<PatternMatcher>>,
}
impl SimdPatternMatcher {
pub fn new() -> Self {
let config = SimdPatternConfig::default();
Self::with_config(config)
}
pub fn with_config(config: SimdPatternConfig) -> Self {
let cpu_features = get_cpu_features();
let simd_tier = Self::select_optimal_tier(&config, cpu_features);
let simd_ops = get_global_simd_ops();
let simd_search = get_global_simd_search();
Self {
config,
cpu_features,
simd_tier,
simd_ops,
simd_search,
suffix_array: None,
dictionary_text: None,
stats: SimdPatternStats::default(),
fallback_matcher: None,
}
}
pub fn with_dictionary(
config: SimdPatternConfig,
suffix_array: Arc<SuffixArray>,
dictionary_text: Arc<Vec<u8>>,
) -> Result<Self> {
let mut matcher = Self::with_config(config.clone());
let fallback_config = MatcherConfig {
enable_simd: false, ..Default::default()
};
let fallback = PatternMatcher::with_config(
suffix_array.clone(),
dictionary_text.clone(),
fallback_config,
);
matcher.suffix_array = Some(suffix_array);
matcher.dictionary_text = Some(dictionary_text);
matcher.fallback_matcher = Some(Arc::new(fallback));
Ok(matcher)
}
fn select_optimal_tier(config: &SimdPatternConfig, features: &CpuFeatures) -> SimdPatternTier {
if !config.enable_simd {
return SimdPatternTier::Scalar;
}
#[cfg(feature = "avx512")]
if features.has_avx512f && features.has_avx512vl && features.has_avx512bw {
return SimdPatternTier::Avx512;
}
if features.has_avx2 && config.enable_bmi2 && features.has_bmi2 {
return SimdPatternTier::Avx2;
}
if features.has_sse41 && features.has_sse42 {
return SimdPatternTier::Sse42;
}
SimdPatternTier::Scalar
}
pub fn simd_tier(&self) -> SimdPatternTier {
self.simd_tier
}
pub fn config(&self) -> &SimdPatternConfig {
&self.config
}
pub fn stats(&self) -> &SimdPatternStats {
&self.stats
}
pub fn reset_stats(&mut self) {
self.stats = SimdPatternStats::default();
}
}
impl SimdPatternMatcher {
pub fn find_pattern_matches(
&mut self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<SimdMatchResult>> {
let start_time = std::time::Instant::now();
if input.is_empty() || pattern.is_empty() {
return Ok(Vec::new());
}
let mut results = Vec::new();
let pattern_len = pattern.len();
let matches = if pattern_len <= self.config.max_single_sse_length {
self.find_matches_single_sse(input, pattern, max_matches)?
} else if pattern_len <= self.config.max_cascaded_sse_length {
self.find_matches_cascaded_sse(input, pattern, max_matches)?
} else {
self.find_matches_vectorized(input, pattern, max_matches)?
};
let search_time_ns = start_time.elapsed().as_nanos() as u64;
for (input_pos, dict_pos) in matches {
let base_match = self.create_base_match(pattern, dict_pos)?;
let mut result = SimdMatchResult::new(
base_match,
input_pos,
dict_pos,
self.simd_tier,
self.config.parallel_mode,
self.simd_tier != SimdPatternTier::Scalar,
);
result.search_time_ns = search_time_ns / (results.len() + 1) as u64;
result.simd_operations = self.estimate_simd_operations(pattern_len);
self.stats.update_with_result(&result);
results.push(result.clone());
if self.config.enable_early_termination &&
result.quality >= self.config.early_termination_quality {
break;
}
}
Ok(results)
}
fn find_matches_single_sse(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
match self.simd_tier {
SimdPatternTier::Sse42 | SimdPatternTier::Avx2 => {
self.sse42_single_pattern_search(input, pattern, max_matches)
}
#[cfg(feature = "avx512")]
SimdPatternTier::Avx512 => {
self.avx512_single_pattern_search(input, pattern, max_matches)
}
SimdPatternTier::Scalar => {
self.scalar_pattern_search(input, pattern, max_matches)
}
}
}
fn find_matches_cascaded_sse(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
match self.simd_tier {
SimdPatternTier::Sse42 | SimdPatternTier::Avx2 => {
self.sse42_cascaded_pattern_search(input, pattern, max_matches)
}
#[cfg(feature = "avx512")]
SimdPatternTier::Avx512 => {
self.avx512_cascaded_pattern_search(input, pattern, max_matches)
}
SimdPatternTier::Scalar => {
self.scalar_pattern_search(input, pattern, max_matches)
}
}
}
fn find_matches_vectorized(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
match self.simd_tier {
SimdPatternTier::Avx2 => {
self.avx2_vectorized_pattern_search(input, pattern, max_matches)
}
#[cfg(feature = "avx512")]
SimdPatternTier::Avx512 => {
self.avx512_vectorized_pattern_search(input, pattern, max_matches)
}
SimdPatternTier::Sse42 | SimdPatternTier::Scalar => {
self.fallback_pattern_search(input, pattern, max_matches)
}
}
}
#[cfg(target_arch = "x86_64")]
fn sse42_single_pattern_search(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
let mut matches = Vec::new();
if pattern.is_empty() || input.len() < pattern.len() {
return Ok(matches);
}
let mut search_pos = 0;
while search_pos <= input.len() - pattern.len() && matches.len() < max_matches {
if let Some(found_pos) = self.simd_search.sse42_strstr(&input[search_pos..], pattern) {
let absolute_pos = search_pos + found_pos;
matches.push((absolute_pos, absolute_pos)); search_pos = absolute_pos + 1;
} else {
break;
}
}
Ok(matches)
}
#[cfg(target_arch = "x86_64")]
fn sse42_cascaded_pattern_search(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
let mut matches = Vec::new();
if pattern.is_empty() || input.len() < pattern.len() {
return Ok(matches);
}
let search_pattern = if pattern.len() > 16 {
&pattern[..16]
} else {
pattern
};
let mut search_pos = 0;
while search_pos <= input.len() - pattern.len() && matches.len() < max_matches {
if let Some(found_pos) = self.simd_search.sse42_strstr(&input[search_pos..], search_pattern) {
let absolute_pos = search_pos + found_pos;
if absolute_pos + pattern.len() <= input.len() {
let candidate = &input[absolute_pos..absolute_pos + pattern.len()];
if candidate == pattern {
matches.push((absolute_pos, absolute_pos));
}
}
search_pos = absolute_pos + 1;
} else {
break;
}
}
Ok(matches)
}
#[cfg(target_arch = "x86_64")]
fn avx2_vectorized_pattern_search(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
let mut matches = Vec::new();
if pattern.is_empty() || input.len() < pattern.len() {
return Ok(matches);
}
let first_char = pattern[0];
let mut search_pos = 0;
while search_pos <= input.len() - pattern.len() && matches.len() < max_matches {
if let Some(found_pos) = self.simd_ops.find_byte(&input[search_pos..], first_char) {
let absolute_pos = search_pos + found_pos;
if absolute_pos + pattern.len() <= input.len() {
let candidate = &input[absolute_pos..absolute_pos + pattern.len()];
if self.simd_ops.compare(candidate, pattern) == 0 {
matches.push((absolute_pos, absolute_pos));
}
}
search_pos = absolute_pos + 1;
} else {
break;
}
}
Ok(matches)
}
#[cfg(all(feature = "avx512", target_arch = "x86_64"))]
fn avx512_single_pattern_search(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
self.sse42_single_pattern_search(input, pattern, max_matches)
}
#[cfg(all(feature = "avx512", target_arch = "x86_64"))]
fn avx512_cascaded_pattern_search(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
self.sse42_cascaded_pattern_search(input, pattern, max_matches)
}
#[cfg(all(feature = "avx512", target_arch = "x86_64"))]
fn avx512_vectorized_pattern_search(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
self.avx2_vectorized_pattern_search(input, pattern, max_matches)
}
#[cfg(not(target_arch = "x86_64"))]
fn sse42_single_pattern_search(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
self.scalar_pattern_search(input, pattern, max_matches)
}
#[cfg(not(target_arch = "x86_64"))]
fn sse42_cascaded_pattern_search(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
self.scalar_pattern_search(input, pattern, max_matches)
}
#[cfg(not(target_arch = "x86_64"))]
fn avx2_vectorized_pattern_search(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
self.scalar_pattern_search(input, pattern, max_matches)
}
#[cfg(not(all(feature = "avx512", target_arch = "x86_64")))]
fn avx512_single_pattern_search(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
self.sse42_single_pattern_search(input, pattern, max_matches)
}
#[cfg(not(all(feature = "avx512", target_arch = "x86_64")))]
fn avx512_cascaded_pattern_search(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
self.sse42_cascaded_pattern_search(input, pattern, max_matches)
}
#[cfg(not(all(feature = "avx512", target_arch = "x86_64")))]
fn avx512_vectorized_pattern_search(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
self.avx2_vectorized_pattern_search(input, pattern, max_matches)
}
fn scalar_pattern_search(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
let mut matches = Vec::new();
if pattern.is_empty() || input.len() < pattern.len() {
return Ok(matches);
}
for i in 0..=input.len() - pattern.len() {
if matches.len() >= max_matches {
break;
}
if &input[i..i + pattern.len()] == pattern {
matches.push((i, i));
}
}
Ok(matches)
}
fn fallback_pattern_search(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
if let Some(ref fallback) = self.fallback_matcher {
let matches = fallback.find_all_matches(pattern, max_matches)?;
Ok(matches.into_iter()
.map(|m| (m.dict_position, m.dict_position))
.collect())
} else {
self.scalar_pattern_search(input, pattern, max_matches)
}
}
fn create_base_match(&self, pattern: &[u8], dict_position: usize) -> Result<Match> {
let length = pattern.len();
if length <= 32 {
Match::literal(length as u8)
} else {
Match::global(dict_position as u32, length as u16)
}
}
fn estimate_simd_operations(&self, pattern_len: usize) -> u32 {
(match self.simd_tier {
SimdPatternTier::Scalar => 0,
SimdPatternTier::Sse42 => {
if pattern_len <= 16 {
1
} else if pattern_len <= 35 {
(pattern_len + 15) / 16
} else {
pattern_len / 16 + 1
}
}
SimdPatternTier::Avx2 => {
if pattern_len <= 32 {
1
} else {
(pattern_len + 31) / 32
}
}
#[cfg(feature = "avx512")]
SimdPatternTier::Avx512 => {
if pattern_len <= 64 {
1
} else {
(pattern_len + 63) / 64
}
}
}) as u32
}
pub fn find_parallel_pattern_matches(
&mut self,
input: &[u8],
patterns: &[&[u8]],
max_matches_per_pattern: usize,
parallel_mode: ParallelMode,
) -> Result<Vec<Vec<SimdMatchResult>>> {
if !parallel_mode.is_supported(self.cpu_features) {
return self.find_parallel_pattern_matches(input, patterns, max_matches_per_pattern, ParallelMode::X1);
}
match parallel_mode {
ParallelMode::X1 => self.find_patterns_x1(input, patterns, max_matches_per_pattern),
ParallelMode::X2 => self.find_patterns_x2(input, patterns, max_matches_per_pattern),
ParallelMode::X4 => self.find_patterns_x4(input, patterns, max_matches_per_pattern),
ParallelMode::X8 => self.find_patterns_x8(input, patterns, max_matches_per_pattern),
}
}
fn find_patterns_x1(
&mut self,
input: &[u8],
patterns: &[&[u8]],
max_matches_per_pattern: usize,
) -> Result<Vec<Vec<SimdMatchResult>>> {
let mut results = Vec::with_capacity(patterns.len());
for pattern in patterns {
let matches = self.find_pattern_matches(input, pattern, max_matches_per_pattern)?;
results.push(matches);
}
Ok(results)
}
fn find_patterns_x2(
&mut self,
input: &[u8],
patterns: &[&[u8]],
max_matches_per_pattern: usize,
) -> Result<Vec<Vec<SimdMatchResult>>> {
if patterns.len() <= 1 {
return self.find_patterns_x1(input, patterns, max_matches_per_pattern);
}
let mid = patterns.len() / 2;
let (left_patterns, right_patterns) = patterns.split_at(mid);
let left_results = self.find_patterns_x1(input, left_patterns, max_matches_per_pattern)?;
let right_results = self.find_patterns_x1(input, right_patterns, max_matches_per_pattern)?;
let mut combined_results = Vec::with_capacity(patterns.len());
combined_results.extend(left_results);
combined_results.extend(right_results);
Ok(combined_results)
}
fn find_patterns_x4(
&mut self,
input: &[u8],
patterns: &[&[u8]],
max_matches_per_pattern: usize,
) -> Result<Vec<Vec<SimdMatchResult>>> {
if patterns.len() <= 2 {
return self.find_patterns_x2(input, patterns, max_matches_per_pattern);
}
let quarter_size = patterns.len() / 4;
let mut results = Vec::with_capacity(patterns.len());
for chunk_start in (0..patterns.len()).step_by(quarter_size.max(1)) {
let chunk_end = (chunk_start + quarter_size).min(patterns.len());
let chunk = &patterns[chunk_start..chunk_end];
let chunk_results = self.find_patterns_x1(input, chunk, max_matches_per_pattern)?;
results.extend(chunk_results);
}
Ok(results)
}
fn find_patterns_x8(
&mut self,
input: &[u8],
patterns: &[&[u8]],
max_matches_per_pattern: usize,
) -> Result<Vec<Vec<SimdMatchResult>>> {
if patterns.len() <= 4 {
return self.find_patterns_x4(input, patterns, max_matches_per_pattern);
}
let eighth_size = patterns.len() / 8;
let mut results = Vec::with_capacity(patterns.len());
for chunk_start in (0..patterns.len()).step_by(eighth_size.max(1)) {
let chunk_end = (chunk_start + eighth_size).min(patterns.len());
let chunk = &patterns[chunk_start..chunk_end];
let chunk_results = self.find_patterns_x1(input, chunk, max_matches_per_pattern)?;
results.extend(chunk_results);
}
Ok(results)
}
pub fn find_compression_aware_matches(
&mut self,
input: &[u8],
pattern: &[u8],
compression_type: CompressionType,
max_matches: usize,
) -> Result<Vec<SimdMatchResult>> {
let start_time = std::time::Instant::now();
let matches = match compression_type {
CompressionType::Literal => {
self.find_literal_matches(input, pattern, max_matches)?
}
CompressionType::Global => {
self.find_global_matches(input, pattern, max_matches)?
}
CompressionType::RLE => {
self.find_rle_matches(input, pattern, max_matches)?
}
CompressionType::NearShort | CompressionType::Far1Short | CompressionType::Far2Short => {
self.find_distance_optimized_matches(input, pattern, max_matches, true)?
}
CompressionType::Far2Long | CompressionType::Far3Long => {
self.find_distance_optimized_matches(input, pattern, max_matches, false)?
}
};
let search_time_ns = start_time.elapsed().as_nanos() as u64;
let mut results = Vec::new();
for (input_pos, dict_pos) in matches {
let base_match = self.create_compression_aware_match(pattern, dict_pos, compression_type)?;
let mut result = SimdMatchResult::new(
base_match,
input_pos,
dict_pos,
self.simd_tier,
self.config.parallel_mode,
self.simd_tier != SimdPatternTier::Scalar,
);
result.search_time_ns = search_time_ns / (results.len() + 1) as u64;
result.simd_operations = self.estimate_simd_operations(pattern.len());
self.stats.update_with_result(&result);
results.push(result);
}
Ok(results)
}
fn find_literal_matches(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
self.find_matches_single_sse(input, pattern, max_matches)
}
fn find_global_matches(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
if let Some(ref fallback) = self.fallback_matcher {
let matches = fallback.find_all_matches(pattern, max_matches)?;
Ok(matches.into_iter()
.map(|m| (m.input_position, m.dict_position))
.collect())
} else {
self.find_matches_single_sse(input, pattern, max_matches)
}
}
fn find_rle_matches(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
) -> Result<Vec<(usize, usize)>> {
if pattern.is_empty() {
return Ok(Vec::new());
}
let mut matches = Vec::new();
if pattern.len() >= 4 && pattern.iter().all(|&b| b == pattern[0]) {
let target_byte = pattern[0];
let run_length = pattern.len();
let mut pos = 0;
while pos <= input.len() - run_length && matches.len() < max_matches {
if let Some(found_pos) = self.simd_ops.find_byte(&input[pos..], target_byte) {
let absolute_pos = pos + found_pos;
let mut run_len = 0;
for i in absolute_pos..input.len() {
if input[i] == target_byte {
run_len += 1;
} else {
break;
}
}
if run_len >= run_length {
matches.push((absolute_pos, absolute_pos));
}
pos = absolute_pos + 1;
} else {
break;
}
}
} else {
return self.find_matches_single_sse(input, pattern, max_matches);
}
Ok(matches)
}
fn find_distance_optimized_matches(
&self,
input: &[u8],
pattern: &[u8],
max_matches: usize,
is_short: bool,
) -> Result<Vec<(usize, usize)>> {
if is_short && pattern.len() <= self.config.max_single_sse_length {
self.find_matches_single_sse(input, pattern, max_matches)
} else {
self.find_matches_vectorized(input, pattern, max_matches)
}
}
fn create_compression_aware_match(
&self,
pattern: &[u8],
dict_position: usize,
compression_type: CompressionType,
) -> Result<Match> {
let length = pattern.len();
match compression_type {
CompressionType::Literal => {
Match::literal(length.min(255) as u8)
}
CompressionType::Global => {
Match::global(dict_position as u32, length.min(65535) as u16)
}
CompressionType::RLE => {
if length <= 255 {
Match::literal(length as u8) } else {
Match::global(dict_position as u32, length.min(65535) as u16)
}
}
CompressionType::NearShort => {
let distance = dict_position.min(255) as u8;
Match::near_short(distance, length.min(255) as u8)
}
CompressionType::Far1Short => {
let distance = dict_position.min(65535) as u16;
Match::far1_short(distance, length.min(255) as u8)
}
CompressionType::Far2Short => {
let distance = dict_position as u32;
Match::far2_short(distance, length.min(255) as u8)
}
CompressionType::Far2Long => {
let distance = dict_position as u32;
Match::far2_long(distance.try_into().unwrap_or(u16::MAX), length.min(65535) as u16)
}
CompressionType::Far3Long => {
let distance = dict_position as u32;
Match::far3_long(distance, length as u32)
}
}
}
}
impl Default for SimdPatternMatcher {
fn default() -> Self {
Self::new()
}
}
static GLOBAL_SIMD_PATTERN_MATCHER: std::sync::OnceLock<SimdPatternMatcher> = std::sync::OnceLock::new();
pub fn get_global_simd_pattern_matcher() -> &'static SimdPatternMatcher {
GLOBAL_SIMD_PATTERN_MATCHER.get_or_init(|| SimdPatternMatcher::new())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_simd_pattern_matcher_creation() {
let matcher = SimdPatternMatcher::new();
assert!(matches!(matcher.simd_tier(),
SimdPatternTier::Scalar |
SimdPatternTier::Sse42 |
SimdPatternTier::Avx2
));
}
#[test]
fn test_parallel_mode_support() {
let cpu_features = get_cpu_features();
assert!(ParallelMode::X1.is_supported(cpu_features));
println!("X2 supported: {}", ParallelMode::X2.is_supported(cpu_features));
println!("X4 supported: {}", ParallelMode::X4.is_supported(cpu_features));
println!("X8 supported: {}", ParallelMode::X8.is_supported(cpu_features));
}
#[test]
fn test_config_presets() {
let high_throughput = SimdPatternConfig::high_throughput();
assert_eq!(high_throughput.parallel_mode, ParallelMode::X4);
let low_latency = SimdPatternConfig::low_latency();
assert_eq!(low_latency.parallel_mode, ParallelMode::X1);
let max_parallel = SimdPatternConfig::maximum_parallelism();
assert_eq!(max_parallel.parallel_mode, ParallelMode::X8);
}
#[test]
fn test_simd_match_result() {
let base_match = Match::literal(10).unwrap();
let result = SimdMatchResult::new(
base_match,
0,
100,
SimdPatternTier::Avx2,
ParallelMode::X2,
true,
);
assert_eq!(result.length, 10);
assert_eq!(result.input_position, 0);
assert_eq!(result.dict_position, 100);
assert_eq!(result.simd_tier, SimdPatternTier::Avx2);
assert_eq!(result.parallel_mode, ParallelMode::X2);
assert!(result.simd_accelerated);
assert!(result.quality > 0.0);
}
#[test]
fn test_simd_match_comparison() {
let base_match1 = Match::literal(8).unwrap();
let result1 = SimdMatchResult::new(
base_match1,
0,
100,
SimdPatternTier::Sse42,
ParallelMode::X1,
false,
);
let base_match2 = Match::literal(10).unwrap();
let result2 = SimdMatchResult::new(
base_match2,
0,
200,
SimdPatternTier::Avx2,
ParallelMode::X2,
true,
);
assert!(result2.is_better_than(&result1));
assert!(!result1.is_better_than(&result2));
}
#[test]
fn test_stats_update() {
let mut stats = SimdPatternStats::default();
let base_match = Match::literal(12).unwrap();
let mut result = SimdMatchResult::new(
base_match,
0,
100,
SimdPatternTier::Avx2,
ParallelMode::X1,
true,
);
result.search_time_ns = 1000;
result.simd_operations = 5;
stats.update_with_result(&result);
assert_eq!(stats.total_searches, 1);
assert_eq!(stats.simd_accelerated_searches, 1);
assert_eq!(stats.successful_matches, 1);
assert_eq!(stats.total_simd_operations, 5);
assert_eq!(stats.total_search_time_ns, 1000);
assert_eq!(stats.success_rate(), 1.0);
assert_eq!(stats.simd_acceleration_ratio(), 1.0);
}
#[test]
fn test_global_instance() {
let matcher1 = get_global_simd_pattern_matcher();
let matcher2 = get_global_simd_pattern_matcher();
assert_eq!(matcher1.simd_tier(), matcher2.simd_tier());
}
#[test]
fn test_tier_selection() {
let mut config = SimdPatternConfig::default();
let cpu_features = get_cpu_features();
config.enable_simd = false;
let tier = SimdPatternMatcher::select_optimal_tier(&config, cpu_features);
assert_eq!(tier, SimdPatternTier::Scalar);
config.enable_simd = true;
let tier = SimdPatternMatcher::select_optimal_tier(&config, cpu_features);
assert!(matches!(tier,
SimdPatternTier::Scalar |
SimdPatternTier::Sse42 |
SimdPatternTier::Avx2
));
}
#[test]
fn test_pattern_matching_basic() {
let mut matcher = SimdPatternMatcher::new();
let input = b"hello world hello universe hello multiverse";
let pattern = b"hello";
let results = matcher.find_pattern_matches(input, pattern, 10).unwrap();
assert!(results.len() >= 3);
for result in &results {
assert_eq!(result.length, 5);
assert!(result.input_position < input.len());
assert!(result.quality > 0.0);
}
}
#[test]
fn test_multi_level_search_strategy() {
let mut matcher = SimdPatternMatcher::new();
let input = b"abcdefghijklmnopqrstuvwxyz".repeat(10);
let short_pattern = b"abcdef";
let short_results = matcher.find_pattern_matches(&input, short_pattern, 5).unwrap();
assert!(!short_results.is_empty());
let medium_pattern = b"abcdefghijklmnopqrstuvwxyz";
let medium_results = matcher.find_pattern_matches(&input, medium_pattern, 5).unwrap();
assert!(!medium_results.is_empty());
let large_pattern = b"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz";
let large_results = matcher.find_pattern_matches(&input, large_pattern, 5).unwrap();
println!("Large pattern results: {}", large_results.len());
}
#[test]
fn test_parallel_processing_variants() {
let mut matcher = SimdPatternMatcher::new();
let input = b"pattern1 pattern2 pattern3 pattern4 pattern1 pattern2 pattern3 pattern4";
let patterns = vec![b"pattern1".as_slice(), b"pattern2".as_slice(), b"pattern3".as_slice(), b"pattern4".as_slice()];
let x1_results = matcher.find_parallel_pattern_matches(input, &patterns, 5, ParallelMode::X1).unwrap();
assert_eq!(x1_results.len(), 4);
for pattern_results in &x1_results {
assert!(!pattern_results.is_empty());
}
let x2_results = matcher.find_parallel_pattern_matches(input, &patterns, 5, ParallelMode::X2).unwrap();
assert_eq!(x2_results.len(), 4);
let x4_results = matcher.find_parallel_pattern_matches(input, &patterns, 5, ParallelMode::X4).unwrap();
assert_eq!(x4_results.len(), 4);
let x8_results = matcher.find_parallel_pattern_matches(input, &patterns, 5, ParallelMode::X8).unwrap();
assert_eq!(x8_results.len(), 4);
}
#[test]
fn test_compression_aware_matching() {
let mut matcher = SimdPatternMatcher::new();
let input = b"literal data with repeated bytes aaaaaaaaaa and global references";
let literal_pattern = b"literal";
let literal_results = matcher.find_compression_aware_matches(
input, literal_pattern, CompressionType::Literal, 5
).unwrap();
assert!(!literal_results.is_empty());
let rle_pattern = b"aaaaaaaaaa";
let rle_results = matcher.find_compression_aware_matches(
input, rle_pattern, CompressionType::RLE, 5
).unwrap();
assert!(!rle_results.is_empty());
let global_pattern = b"global";
let global_results = matcher.find_compression_aware_matches(
input, global_pattern, CompressionType::Global, 5
).unwrap();
println!("Global results: {}", global_results.len());
}
#[test]
fn test_rle_optimization() {
let mut matcher = SimdPatternMatcher::new();
let input = b"before aaaaaaaaaa middle bbbbbbbbbb after";
let rle_pattern = b"aaaaaaaaaa";
let results = matcher.find_rle_matches(input, rle_pattern, 5).unwrap();
assert!(!results.is_empty());
assert_eq!(results[0].0, 7);
let non_rle_pattern = b"before";
let non_rle_results = matcher.find_rle_matches(input, non_rle_pattern, 5).unwrap();
assert!(!non_rle_results.is_empty());
}
#[test]
fn test_simd_match_quality() {
let base_match = Match::literal(15).unwrap();
let simd_result = SimdMatchResult::new(
base_match.clone(),
0,
100,
SimdPatternTier::Avx2,
ParallelMode::X1,
true, );
let scalar_result = SimdMatchResult::new(
base_match,
0,
100,
SimdPatternTier::Scalar,
ParallelMode::X1,
false, );
assert!(simd_result.quality > scalar_result.quality);
assert!(simd_result.is_better_than(&scalar_result));
}
#[test]
fn test_early_termination() {
let config = SimdPatternConfig {
enable_early_termination: true,
early_termination_quality: 0.5,
..Default::default()
};
let mut matcher = SimdPatternMatcher::with_config(config);
let input = b"test pattern test pattern test pattern test pattern";
let pattern = b"test";
let results = matcher.find_pattern_matches(input, pattern, 100).unwrap();
assert!(!results.is_empty());
println!("Early termination test found {} matches", results.len());
}
#[test]
fn test_performance_statistics() {
let mut matcher = SimdPatternMatcher::new();
let input = b"statistics test data with multiple patterns and matches";
let pattern = b"test";
for _ in 0..5 {
let _ = matcher.find_pattern_matches(input, pattern, 10).unwrap();
}
let stats = matcher.stats();
assert_eq!(stats.total_searches, 5);
assert!(stats.avg_search_time_ns() > 0.0);
matcher.reset_stats();
let reset_stats = matcher.stats();
assert_eq!(reset_stats.total_searches, 0);
}
#[test]
fn test_config_presets_detailed() {
let high_throughput = SimdPatternConfig::high_throughput();
assert_eq!(high_throughput.parallel_mode, ParallelMode::X4);
assert_eq!(high_throughput.max_simd_operations, 2000);
assert_eq!(high_throughput.early_termination_quality, 0.9);
let low_latency = SimdPatternConfig::low_latency();
assert_eq!(low_latency.parallel_mode, ParallelMode::X1);
assert_eq!(low_latency.max_simd_operations, 500);
assert_eq!(low_latency.early_termination_quality, 0.98);
let max_parallel = SimdPatternConfig::maximum_parallelism();
assert_eq!(max_parallel.parallel_mode, ParallelMode::X8);
assert_eq!(max_parallel.max_simd_operations, 5000);
assert_eq!(max_parallel.early_termination_quality, 0.85);
}
#[test]
fn test_empty_inputs() {
let mut matcher = SimdPatternMatcher::new();
let empty_results = matcher.find_pattern_matches(b"", b"pattern", 10).unwrap();
assert!(empty_results.is_empty());
let empty_pattern_results = matcher.find_pattern_matches(b"input", b"", 10).unwrap();
assert!(empty_pattern_results.is_empty());
let both_empty_results = matcher.find_pattern_matches(b"", b"", 10).unwrap();
assert!(both_empty_results.is_empty());
}
#[test]
fn test_large_pattern_fallback() {
let mut matcher = SimdPatternMatcher::new();
let input = b"This is a test string with some content to search through for large patterns";
let large_pattern = b"This is a test string with some content to search through for large patterns that exceed the cascaded SSE limit";
let results = matcher.find_pattern_matches(input, large_pattern, 5).unwrap();
println!("Large pattern fallback test completed with {} results", results.len());
}
#[test]
fn test_simd_operations_estimation() {
let matcher = SimdPatternMatcher::new();
match matcher.simd_tier() {
SimdPatternTier::Scalar => {
assert_eq!(matcher.estimate_simd_operations(10), 0);
}
SimdPatternTier::Sse42 => {
assert_eq!(matcher.estimate_simd_operations(10), 1); assert_eq!(matcher.estimate_simd_operations(25), 2); assert_eq!(matcher.estimate_simd_operations(50), 4); }
SimdPatternTier::Avx2 => {
assert_eq!(matcher.estimate_simd_operations(20), 1); assert_eq!(matcher.estimate_simd_operations(50), 2); }
#[cfg(feature = "avx512")]
SimdPatternTier::Avx512 => {
assert_eq!(matcher.estimate_simd_operations(40), 1); assert_eq!(matcher.estimate_simd_operations(100), 2); }
}
}
#[test]
fn test_compression_type_match_creation() {
let matcher = SimdPatternMatcher::new();
let pattern = b"test";
let test_cases = [
(CompressionType::Literal, 0), (CompressionType::RLE, 0), ];
for (compression_type, dict_pos) in &test_cases {
let result = matcher.create_compression_aware_match(pattern, *dict_pos, *compression_type);
match result {
Ok(_) => {
println!("Successfully created match for compression type: {:?}", compression_type);
}
Err(e) => {
println!("Error creating match for compression type: {:?} with dict_pos: {}: {:?}", compression_type, dict_pos, e);
panic!("Failed to create match for compression type: {:?} with dict_pos: {} - Error: {:?}", compression_type, dict_pos, e);
}
}
}
}
}