#[cfg(feature = "gpu")]
use crate::gpu::GpuAccelerator;
#[cfg(feature = "gpu")]
use roaring::RoaringBitmap;
#[cfg(feature = "gpu")]
use std::collections::HashSet;
#[cfg(feature = "gpu")]
use std::sync::Arc;
/// Trigram search front-end holding a handle to the shared GPU accelerator.
///
/// NOTE(review): the search and extraction methods below currently run entirely
/// on the CPU; the accelerator handle is only stored and exposed through
/// [`GpuTrigramAccelerator::accelerator`].
#[cfg(feature = "gpu")]
pub struct GpuTrigramAccelerator {
    /// Shared handle obtained from `GpuAccelerator::global()`.
    accelerator: Arc<GpuAccelerator>,
}

#[cfg(feature = "gpu")]
impl GpuTrigramAccelerator {
    /// Creates an accelerator bound to the process-wide GPU handle.
    ///
    /// # Errors
    ///
    /// Returns `Err("GPU not available")` when `GpuAccelerator::global()`
    /// yields no accelerator.
    pub fn new() -> Result<Self, String> {
        let accelerator = GpuAccelerator::global().ok_or("GPU not available")?;
        Ok(Self { accelerator })
    }

    /// Reports whether a GPU can be used, delegating to `GpuAccelerator::is_available`.
    #[must_use]
    pub fn is_available() -> bool {
        GpuAccelerator::is_available()
    }

    /// Resolves each pattern against `inverted_index`, returning one bitmap per
    /// pattern in input order.
    ///
    /// A document id appears in a pattern's bitmap when it is present in the
    /// posting list of every trigram of that pattern (i.e. it is a candidate
    /// match). Patterns shorter than three bytes, or containing a trigram that
    /// is absent from the index, yield an empty bitmap.
    ///
    /// The work is done on the CPU; the GPU handle is not used here.
    #[must_use]
    pub fn batch_search(
        &self,
        patterns: &[&str],
        inverted_index: &std::collections::HashMap<[u8; 3], RoaringBitmap>,
    ) -> Vec<RoaringBitmap> {
        patterns
            .iter()
            .map(|pattern| Self::search_single(pattern, inverted_index))
            .collect()
    }

    /// Intersects the posting lists of every distinct trigram in `pattern`.
    fn search_single(
        pattern: &str,
        inverted_index: &std::collections::HashMap<[u8; 3], RoaringBitmap>,
    ) -> RoaringBitmap {
        let trigrams = Self::extract_trigrams_cpu(pattern);

        // Look up every posting list first: a single missing trigram means no
        // document can contain the whole pattern, so bail out before cloning.
        let mut postings = Vec::with_capacity(trigrams.len());
        for trigram in &trigrams {
            match inverted_index.get(trigram) {
                Some(bitmap) => postings.push(bitmap),
                None => return RoaringBitmap::new(),
            }
        }

        // Starting from the smallest list keeps the cloned seed and every
        // subsequent AND as small as possible; intersection order does not
        // affect the result.
        postings.sort_unstable_by_key(|bitmap| bitmap.len());

        let mut remaining = postings.into_iter();
        let mut result = match remaining.next() {
            Some(smallest) => smallest.clone(),
            // No trigrams at all: the pattern is shorter than three bytes.
            None => return RoaringBitmap::new(),
        };
        for bitmap in remaining {
            if result.is_empty() {
                // AND with anything keeps an empty set empty.
                break;
            }
            result &= bitmap;
        }
        result
    }

    /// Extracts the distinct trigrams of every document, one set per document
    /// in input order.
    ///
    /// The work is done on the CPU; the GPU handle is not used here.
    #[must_use]
    pub fn batch_extract_trigrams(&self, documents: &[&str]) -> Vec<HashSet<[u8; 3]>> {
        documents
            .iter()
            .map(|doc| Self::extract_trigrams_cpu(doc))
            .collect()
    }

    /// Collects the distinct 3-byte windows of `text`.
    ///
    /// Windows are taken over the raw UTF-8 bytes, so a trigram may straddle a
    /// multi-byte character. Texts shorter than three bytes give an empty set.
    fn extract_trigrams_cpu(text: &str) -> HashSet<[u8; 3]> {
        text.as_bytes()
            .windows(3)
            .map(|window| [window[0], window[1], window[2]])
            .collect()
    }

    /// Borrows the underlying shared GPU accelerator.
    #[must_use]
    pub fn accelerator(&self) -> &GpuAccelerator {
        &self.accelerator
    }
}
/// Compute backend used for trigram work.
///
/// `CpuSimd` is the `Default`; the `Gpu` variant only exists when the crate is
/// built with the `gpu` feature.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TrigramComputeBackend {
    /// CPU backend; the default, and the only variant without the `gpu` feature.
    #[default]
    CpuSimd,
    /// GPU backend; compiled in only with the `gpu` feature.
    #[cfg(feature = "gpu")]
    Gpu,
}
impl TrigramComputeBackend {
    /// Picks a backend from the expected workload size.
    ///
    /// With the `gpu` feature enabled, the GPU backend is returned when a GPU is
    /// available and either the corpus has more than 500,000 documents, or it
    /// has more than 100,000 documents and more than 10 patterns are searched.
    /// Availability is only queried once a size threshold is met. In every
    /// other case — and always without the `gpu` feature — this returns
    /// [`TrigramComputeBackend::CpuSimd`].
    #[must_use]
    pub fn auto_select(doc_count: usize, pattern_count: usize) -> Self {
        #[cfg(not(feature = "gpu"))]
        let _ = (doc_count, pattern_count);
        #[cfg(feature = "gpu")]
        {
            // Corpus size above which the GPU is selected regardless of pattern count.
            const LARGE_CORPUS_DOCS: usize = 500_000;
            // Corpus size above which the GPU is selected for batched searches.
            const MEDIUM_CORPUS_DOCS: usize = 100_000;
            // Pattern count above which a medium corpus counts as a batch.
            const MIN_BATCH_PATTERNS: usize = 10;

            let large_corpus = doc_count > LARGE_CORPUS_DOCS;
            let batched_medium_corpus =
                doc_count > MEDIUM_CORPUS_DOCS && pattern_count > MIN_BATCH_PATTERNS;
            // `&&` short-circuits, so the GPU probe only runs when a threshold is met.
            if (large_corpus || batched_medium_corpus)
                && crate::gpu::ComputeBackend::gpu_available()
            {
                return Self::Gpu;
            }
        }
        Self::CpuSimd
    }

    /// Returns a short human-readable label for the backend.
    #[must_use]
    pub const fn name(&self) -> &'static str {
        match self {
            Self::CpuSimd => "CPU SIMD",
            #[cfg(feature = "gpu")]
            Self::Gpu => "GPU (wgpu)",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_backend_auto_select_small() {
        let backend = TrigramComputeBackend::auto_select(10_000, 1);
        assert_eq!(backend, TrigramComputeBackend::CpuSimd);
    }

    #[test]
    fn test_backend_auto_select_medium() {
        let backend = TrigramComputeBackend::auto_select(100_000, 5);
        assert_eq!(backend, TrigramComputeBackend::CpuSimd);
    }

    /// The size thresholds use strict `>` comparisons, so inputs sitting exactly
    /// on a boundary must stay on the CPU whether or not `gpu` is enabled.
    #[test]
    fn test_backend_auto_select_at_thresholds() {
        assert_eq!(
            TrigramComputeBackend::auto_select(500_000, 0),
            TrigramComputeBackend::CpuSimd
        );
        assert_eq!(
            TrigramComputeBackend::auto_select(100_001, 10),
            TrigramComputeBackend::CpuSimd
        );
        assert_eq!(
            TrigramComputeBackend::auto_select(100_000, 11),
            TrigramComputeBackend::CpuSimd
        );
    }

    /// Without the `gpu` feature there is no GPU variant, so even huge workloads
    /// fall back to the CPU.
    #[cfg(not(feature = "gpu"))]
    #[test]
    fn test_backend_auto_select_large_without_gpu_feature() {
        let backend = TrigramComputeBackend::auto_select(10_000_000, 100);
        assert_eq!(backend, TrigramComputeBackend::CpuSimd);
    }

    #[test]
    fn test_backend_default_is_cpu() {
        assert_eq!(
            TrigramComputeBackend::default(),
            TrigramComputeBackend::CpuSimd
        );
    }

    #[test]
    fn test_backend_name() {
        assert_eq!(TrigramComputeBackend::CpuSimd.name(), "CPU SIMD");
    }
}
#[cfg(all(test, feature = "gpu"))]
mod gpu_tests {
    use super::*;
    use std::collections::HashMap;

    // The trigram helpers are associated functions that need no GPU, so the
    // `*_cpu` / `search_single` tests below exercise the real logic on every
    // machine. Tests gated on `GpuTrigramAccelerator::new()` silently do nothing
    // when no GPU is present.

    /// Builds a bitmap containing exactly `ids`.
    fn bitmap_of(ids: &[u32]) -> RoaringBitmap {
        ids.iter().copied().collect()
    }

    #[test]
    fn test_gpu_trigram_accelerator_creation() {
        let result = GpuTrigramAccelerator::new();
        if result.is_ok() {
            println!("GPU trigram accelerator created successfully");
        } else {
            println!("No GPU available: {:?}", result.err());
        }
    }

    #[test]
    fn test_gpu_is_available() {
        let _ = GpuTrigramAccelerator::is_available();
    }

    #[test]
    fn test_extract_trigrams_cpu() {
        let trigrams = GpuTrigramAccelerator::extract_trigrams_cpu("hello");
        assert_eq!(trigrams.len(), 3);
        assert!(trigrams.contains(b"hel"));
        assert!(trigrams.contains(b"ell"));
        assert!(trigrams.contains(b"llo"));

        // Repeated windows collapse into a single entry.
        let repeated = GpuTrigramAccelerator::extract_trigrams_cpu("aaaa");
        assert_eq!(repeated.len(), 1);
        assert!(repeated.contains(b"aaa"));
    }

    #[test]
    fn test_extract_trigrams_cpu_short_text() {
        for text in &["ab", "a", ""] {
            assert!(GpuTrigramAccelerator::extract_trigrams_cpu(text).is_empty());
        }
    }

    #[test]
    fn test_search_single_intersects_posting_lists() {
        let mut index: HashMap<[u8; 3], RoaringBitmap> = HashMap::new();
        index.insert(*b"hel", bitmap_of(&[0, 1, 2]));
        index.insert(*b"ell", bitmap_of(&[1, 2]));
        index.insert(*b"llo", bitmap_of(&[2, 3]));
        let result = GpuTrigramAccelerator::search_single("hello", &index);
        assert_eq!(result, bitmap_of(&[2]));
    }

    #[test]
    fn test_search_single_disjoint_lists() {
        let mut index: HashMap<[u8; 3], RoaringBitmap> = HashMap::new();
        index.insert(*b"hel", bitmap_of(&[0]));
        index.insert(*b"ell", bitmap_of(&[1]));
        index.insert(*b"llo", bitmap_of(&[0, 1]));
        let result = GpuTrigramAccelerator::search_single("hello", &index);
        assert!(result.is_empty());
    }

    #[test]
    fn test_search_single_missing_trigram() {
        let mut index: HashMap<[u8; 3], RoaringBitmap> = HashMap::new();
        index.insert(*b"hel", bitmap_of(&[0, 1]));
        index.insert(*b"ell", bitmap_of(&[0, 1]));
        // "llo" is absent, so no document can contain "hello".
        let result = GpuTrigramAccelerator::search_single("hello", &index);
        assert!(result.is_empty());
    }

    #[test]
    fn test_search_single_short_pattern() {
        let mut index: HashMap<[u8; 3], RoaringBitmap> = HashMap::new();
        index.insert(*b"hel", bitmap_of(&[0, 1]));
        let result = GpuTrigramAccelerator::search_single("he", &index);
        assert!(result.is_empty());
    }

    #[test]
    fn test_batch_extract_trigrams() {
        if let Ok(gpu) = GpuTrigramAccelerator::new() {
            let docs = vec!["hello", "world", "test"];
            let results = gpu.batch_extract_trigrams(&docs);
            assert_eq!(results.len(), 3);
            assert!(results[0].contains(b"hel"));
            assert!(results[0].contains(b"ell"));
            assert!(results[0].contains(b"llo"));
        }
    }

    #[test]
    fn test_batch_extract_trigrams_short_text() {
        if let Ok(gpu) = GpuTrigramAccelerator::new() {
            let docs = vec!["ab", "a", ""];
            let results = gpu.batch_extract_trigrams(&docs);
            assert_eq!(results.len(), 3);
            assert!(results[0].is_empty());
            assert!(results[1].is_empty());
            assert!(results[2].is_empty());
        }
    }

    #[test]
    fn test_batch_search_empty_patterns() {
        if let Ok(gpu) = GpuTrigramAccelerator::new() {
            let index: HashMap<[u8; 3], RoaringBitmap> = HashMap::new();
            let results = gpu.batch_search(&[], &index);
            assert!(results.is_empty());
        }
    }

    #[test]
    fn test_batch_search_with_matches() {
        if let Ok(gpu) = GpuTrigramAccelerator::new() {
            let mut index: HashMap<[u8; 3], RoaringBitmap> = HashMap::new();
            let mut bitmap = RoaringBitmap::new();
            bitmap.insert(0);
            bitmap.insert(1);
            index.insert([b'h', b'e', b'l'], bitmap.clone());
            index.insert([b'e', b'l', b'l'], bitmap.clone());
            index.insert([b'l', b'l', b'o'], bitmap);
            let patterns = vec!["hello"];
            let results = gpu.batch_search(&patterns, &index);
            assert_eq!(results.len(), 1);
            assert!(results[0].contains(0));
            assert!(results[0].contains(1));
        }
    }

    #[test]
    fn test_batch_search_no_matches() {
        if let Ok(gpu) = GpuTrigramAccelerator::new() {
            let index: HashMap<[u8; 3], RoaringBitmap> = HashMap::new();
            let patterns = vec!["hello"];
            let results = gpu.batch_search(&patterns, &index);
            assert_eq!(results.len(), 1);
            assert!(results[0].is_empty());
        }
    }
}