use parking_lot::RwLock;
use pgrx::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::OnceLock;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use super::ipc::{
get_shared_memory, BuildIndexRequest, DeleteRequest, InsertRequest, Operation, ResultStatus,
SearchRequest, UpdateIndexRequest, WorkItem, WorkResult,
};
use super::lifecycle::{get_lifecycle_manager, WorkerStatus};
pub use super::ipc::SearchRequest as SearchReq;
/// Tunable settings for a vector-engine background worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EngineWorkerConfig {
    /// Budget, in bytes, for in-memory index data.
    pub max_index_memory: usize,
    /// Upper bound on concurrently serviced searches.
    pub max_concurrent_searches: usize,
    /// Capacity of the shared work queue.
    pub work_queue_size: usize,
    /// Grace period for shutdown, in seconds.
    pub shutdown_timeout_secs: u64,
    /// SIMD toggle (consumed elsewhere — not referenced in this module).
    pub enable_simd: bool,
    /// Prefetch lookahead (consumed elsewhere — not referenced in this module).
    pub prefetch_distance: usize,
    /// Vectors per insert batch (consumed elsewhere — not referenced in this module).
    pub insert_batch_size: usize,
    /// Whether search results are cached by query hash.
    pub enable_query_cache: bool,
    /// Maximum number of cached query results.
    pub query_cache_size: usize,
    /// Cache-entry lifetime, in seconds.
    pub query_cache_ttl_secs: u64,
}
/// Baseline configuration: 4 GiB index budget, 64 concurrent searches,
/// SIMD enabled, and a 10 000-entry query cache with a 60 s TTL.
impl Default for EngineWorkerConfig {
    fn default() -> Self {
        Self {
            max_index_memory: 4 * 1024 * 1024 * 1024,
            max_concurrent_searches: 64,
            work_queue_size: 1024,
            shutdown_timeout_secs: 30,
            prefetch_distance: 4,
            insert_batch_size: 1000,
            enable_simd: true,
            enable_query_cache: true,
            query_cache_size: 10000,
            query_cache_ttl_secs: 60,
        }
    }
}
/// Process-wide engine configuration, initialized lazily on first access.
static ENGINE_CONFIG: OnceLock<RwLock<EngineWorkerConfig>> = OnceLock::new();

/// Returns a snapshot (clone) of the current engine configuration,
/// initializing the cell with defaults on first use.
pub fn get_engine_config() -> EngineWorkerConfig {
    ENGINE_CONFIG
        .get_or_init(|| RwLock::new(EngineWorkerConfig::default()))
        .read()
        .clone()
}

/// Replaces the engine configuration.
///
/// Bug fix: the original only wrote through `ENGINE_CONFIG.get()`, so a
/// call made before the first `get_engine_config()` silently dropped the
/// new configuration. Initializing the cell on demand makes the update
/// always take effect.
pub fn set_engine_config(config: EngineWorkerConfig) {
    *ENGINE_CONFIG
        .get_or_init(|| RwLock::new(EngineWorkerConfig::default()))
        .write() = config;
}
/// Outcome of a single vector search, serializable for IPC transport.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Ids of the matched vectors.
    pub ids: Vec<i64>,
    /// Distances corresponding element-wise to `ids`.
    pub distances: Vec<f32>,
    /// Time spent searching, in microseconds.
    pub search_time_us: u64,
    /// Number of vectors examined during the search.
    pub vectors_scanned: u64,
    /// True when this result was served from the query cache.
    pub cache_hit: bool,
}
impl SearchResult {
    /// An empty result set with all timing counters zeroed.
    pub fn empty() -> Self {
        Self {
            ids: vec![],
            distances: vec![],
            search_time_us: 0,
            vectors_scanned: 0,
            cache_hit: false,
        }
    }

    /// Serializes the result with bincode; an encoding failure yields an
    /// empty byte vector instead of an error.
    pub fn to_bytes(&self) -> Vec<u8> {
        match bincode::serialize(self) {
            Ok(bytes) => bytes,
            Err(_) => Vec::new(),
        }
    }

    /// Decodes a result previously produced by `to_bytes`; returns `None`
    /// on malformed input.
    pub fn from_bytes(data: &[u8]) -> Option<Self> {
        match bincode::deserialize(data) {
            Ok(result) => Some(result),
            Err(_) => None,
        }
    }
}
/// In-memory metadata for one collection's vector index.
#[derive(Debug)]
pub struct CollectionIndex {
    /// Collection this index serves.
    pub collection_id: i32,
    /// Index algorithm name (e.g. "hnsw").
    pub index_type: String,
    /// Number of vectors recorded for this index.
    pub vector_count: u64,
    /// Dimensionality of the indexed vectors (0 when unknown).
    pub dimensions: usize,
    /// Approximate memory footprint in bytes.
    pub memory_bytes: usize,
    /// Unix timestamp (seconds) of creation/last access.
    pub last_access: u64,
    /// Queries served by this index (relaxed counter, stats only).
    pub query_count: AtomicU64,
    /// Whether the index data is resident and searchable.
    pub loaded: AtomicBool,
}
impl CollectionIndex {
    /// Creates metadata for a new, not-yet-loaded index over `collection_id`.
    pub fn new(collection_id: i32, index_type: &str, dimensions: usize) -> Self {
        let created = current_epoch_secs();
        Self {
            collection_id,
            index_type: index_type.to_owned(),
            vector_count: 0,
            dimensions,
            memory_bytes: 0,
            last_access: created,
            query_count: AtomicU64::new(0),
            loaded: AtomicBool::new(false),
        }
    }

    /// Bumps the per-index query counter (relaxed; statistics only).
    pub fn record_query(&self) {
        self.query_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Number of queries served by this index since creation.
    pub fn get_query_count(&self) -> u64 {
        self.query_count.load(Ordering::Relaxed)
    }

    /// Whether the index data has been loaded into memory.
    pub fn is_loaded(&self) -> bool {
        self.loaded.load(Ordering::SeqCst)
    }
}
/// A cached search result paired with its insertion time (epoch seconds),
/// used for TTL-based expiry.
#[derive(Clone)]
struct CacheEntry {
    result: SearchResult,
    created_at: u64,
}
/// Size- and TTL-bounded cache mapping 64-bit query hashes to search results.
struct QueryCache {
    // Lock-guarded map from cache key to entry.
    entries: RwLock<HashMap<u64, CacheEntry>>,
    // Maximum number of entries retained.
    max_size: usize,
    // Entry lifetime in seconds.
    ttl_secs: u64,
}
impl QueryCache {
    /// Creates an empty cache bounded to `max_size` entries whose results
    /// expire `ttl_secs` seconds after insertion.
    fn new(max_size: usize, ttl_secs: u64) -> Self {
        Self {
            entries: RwLock::new(HashMap::new()),
            max_size,
            ttl_secs,
        }
    }

    /// Returns a clone of the cached result for `key` if present and not
    /// expired. Expired entries are left in place; `put` prunes them.
    fn get(&self, key: u64) -> Option<SearchResult> {
        let entries = self.entries.read();
        if let Some(entry) = entries.get(&key) {
            let now = current_epoch_secs();
            // saturating_sub: if the wall clock steps backwards
            // (now < created_at), treat the entry as fresh instead of
            // panicking on u64 underflow in debug builds or wrapping to a
            // huge age in release builds.
            if now.saturating_sub(entry.created_at) < self.ttl_secs {
                return Some(entry.result.clone());
            }
        }
        None
    }

    /// Inserts `result` under `key`. When at capacity, expired entries are
    /// pruned first; if still full, the single oldest entry is evicted.
    fn put(&self, key: u64, result: SearchResult) {
        let mut entries = self.entries.write();
        if entries.len() >= self.max_size {
            let now = current_epoch_secs();
            // Same saturating_sub rationale as in `get`.
            entries.retain(|_, v| now.saturating_sub(v.created_at) < self.ttl_secs);
            if entries.len() >= self.max_size {
                // Still at capacity after pruning: evict the oldest entry.
                if let Some(oldest_key) = entries
                    .iter()
                    .min_by_key(|(_, v)| v.created_at)
                    .map(|(k, _)| *k)
                {
                    entries.remove(&oldest_key);
                }
            }
        }
        entries.insert(
            key,
            CacheEntry {
                result,
                created_at: current_epoch_secs(),
            },
        );
    }

    /// Drops every cached entry (called after writes invalidate results).
    fn clear(&self) {
        self.entries.write().clear();
    }
}
/// Core vector-search engine owned by a single worker: per-collection
/// index metadata, a query-result cache, and operation statistics.
pub struct RuVectorEngine {
    // Configuration captured at construction time.
    config: EngineWorkerConfig,
    // Per-collection index metadata keyed by collection id.
    indexes: RwLock<HashMap<i32, CollectionIndex>>,
    // TTL/size-bounded cache of search results.
    cache: QueryCache,
    // Counters updated by every operation.
    stats: EngineStats,
    // Run flag toggled by start()/stop().
    running: AtomicBool,
}
impl RuVectorEngine {
    /// Builds an engine from `config`. The query cache is sized from the
    /// config before `config` itself is moved into the struct.
    pub fn new(config: EngineWorkerConfig) -> Self {
        Self {
            cache: QueryCache::new(config.query_cache_size, config.query_cache_ttl_secs),
            config,
            indexes: RwLock::new(HashMap::new()),
            stats: EngineStats::new(),
            running: AtomicBool::new(false),
        }
    }

    /// Marks the engine as running; worker loops observe this flag.
    pub fn start(&self) {
        self.running.store(true, Ordering::SeqCst);
    }

    /// Marks the engine as stopped.
    pub fn stop(&self) {
        self.running.store(false, Ordering::SeqCst);
    }

    /// Whether the engine is currently accepting work.
    pub fn is_running(&self) -> bool {
        self.running.load(Ordering::SeqCst)
    }

    /// Loads persisted indexes. Currently a stub that only logs.
    pub fn load_from_storage(&mut self) -> Result<(), String> {
        pgrx::log!("Loading indexes from storage");
        Ok(())
    }

    /// Persists in-memory indexes. Currently a stub that only logs.
    pub fn persist_to_storage(&self) {
        pgrx::log!("Persisting indexes to storage");
    }

    /// Executes a search, serving from the query cache when enabled.
    ///
    /// Errors if the collection is unknown or its index is not loaded.
    /// Fix over the original: the cache key — which hashes the entire
    /// query vector — is now computed at most once per call instead of
    /// twice on the cache-miss path.
    pub fn search(&self, request: &SearchRequest) -> Result<SearchResult, String> {
        let start = Instant::now();
        let cache_key = if self.config.enable_query_cache {
            Some(compute_cache_key(request))
        } else {
            None
        };
        if let Some(key) = cache_key {
            if let Some(mut result) = self.cache.get(key) {
                result.cache_hit = true;
                self.stats.cache_hits.fetch_add(1, Ordering::Relaxed);
                return Ok(result);
            }
            self.stats.cache_misses.fetch_add(1, Ordering::Relaxed);
        }
        let indexes = self.indexes.read();
        let index = indexes
            .get(&request.collection_id)
            .ok_or_else(|| format!("Collection {} not found", request.collection_id))?;
        if !index.is_loaded() {
            return Err(format!(
                "Index for collection {} not loaded",
                request.collection_id
            ));
        }
        index.record_query();
        let result = self.perform_search(index, request)?;
        let elapsed = start.elapsed();
        self.stats.total_searches.fetch_add(1, Ordering::Relaxed);
        self.stats
            .total_search_time_us
            .fetch_add(elapsed.as_micros() as u64, Ordering::Relaxed);
        if let Some(key) = cache_key {
            self.cache.put(key, result.clone());
        }
        Ok(result)
    }

    /// Placeholder search that fabricates `k` results. `_index` is unused
    /// until a real ANN search is wired in; the underscore silences the
    /// unused-variable warning the original emitted.
    fn perform_search(
        &self,
        _index: &CollectionIndex,
        request: &SearchRequest,
    ) -> Result<SearchResult, String> {
        Ok(SearchResult {
            ids: (0..request.k as i64).collect(),
            distances: (0..request.k).map(|i| i as f32 * 0.1).collect(),
            search_time_us: 100,
            vectors_scanned: 1000,
            cache_hit: false,
        })
    }

    /// Records an insert of `request.vectors`, creating the collection's
    /// index record on first use, and invalidates the query cache.
    /// Vector data itself is not yet stored (stub).
    pub fn insert(&mut self, request: &InsertRequest) -> Result<(), String> {
        let start = Instant::now();
        let mut indexes = self.indexes.write();
        // Ensure a record exists; dimensions are unknown until real vectors
        // are indexed, hence 0. (The original bound this to an unused
        // variable, producing a warning.)
        indexes
            .entry(request.collection_id)
            .or_insert_with(|| CollectionIndex::new(request.collection_id, "hnsw", 0));
        let count = request.vectors.len() as u64;
        self.stats.total_inserts.fetch_add(count, Ordering::Relaxed);
        self.stats
            .total_insert_time_us
            .fetch_add(start.elapsed().as_micros() as u64, Ordering::Relaxed);
        // Any write may change search results, so drop all cached results.
        self.cache.clear();
        Ok(())
    }

    /// Records a delete of `request.ids` and invalidates the query cache.
    /// Errors if the collection is unknown; actual removal is a stub.
    pub fn delete(&mut self, request: &DeleteRequest) -> Result<(), String> {
        let start = Instant::now();
        let indexes = self.indexes.read();
        let _index = indexes
            .get(&request.collection_id)
            .ok_or_else(|| format!("Collection {} not found", request.collection_id))?;
        let count = request.ids.len() as u64;
        self.stats.total_deletes.fetch_add(count, Ordering::Relaxed);
        self.stats
            .total_delete_time_us
            .fetch_add(start.elapsed().as_micros() as u64, Ordering::Relaxed);
        self.cache.clear();
        Ok(())
    }

    /// Creates (or replaces) the index record for a collection.
    pub fn build_index(&mut self, request: &BuildIndexRequest) -> Result<(), String> {
        let start = Instant::now();
        pgrx::log!(
            "Building {} index for collection {}",
            request.index_type,
            request.collection_id
        );
        // TODO: dimensions are hard-coded until they can be read from the
        // collection's stored vectors.
        let dimensions = 128;
        let index = CollectionIndex::new(request.collection_id, &request.index_type, dimensions);
        let mut indexes = self.indexes.write();
        indexes.insert(request.collection_id, index);
        self.stats.indexes_built.fetch_add(1, Ordering::Relaxed);
        self.stats
            .total_build_time_us
            .fetch_add(start.elapsed().as_micros() as u64, Ordering::Relaxed);
        Ok(())
    }

    /// Records an update of `request.vectors`. Errors if the collection is
    /// unknown; the index itself is not yet modified (stub).
    pub fn update_index(&mut self, request: &UpdateIndexRequest) -> Result<(), String> {
        let start = Instant::now();
        let mut indexes = self.indexes.write();
        let _index = indexes
            .get_mut(&request.collection_id)
            .ok_or_else(|| format!("Collection {} not found", request.collection_id))?;
        let count = request.vectors.len() as u64;
        self.stats.total_updates.fetch_add(count, Ordering::Relaxed);
        self.stats
            .total_update_time_us
            .fetch_add(start.elapsed().as_micros() as u64, Ordering::Relaxed);
        Ok(())
    }

    /// Point-in-time snapshot of the engine's counters.
    pub fn stats(&self) -> EngineStatsSnapshot {
        self.stats.snapshot()
    }

    /// Number of collections with an index record.
    pub fn index_count(&self) -> usize {
        self.indexes.read().len()
    }

    /// Sum of the recorded memory footprints of all indexes, in bytes.
    pub fn memory_usage(&self) -> usize {
        self.indexes
            .read()
            .values()
            .map(|idx| idx.memory_bytes)
            .sum()
    }
}
/// Atomic operation counters. All accesses use relaxed ordering — these
/// are statistics, not synchronization.
pub struct EngineStats {
    // Search volume and cumulative latency.
    pub total_searches: AtomicU64,
    pub total_search_time_us: AtomicU64,
    // Insert volume and cumulative latency.
    pub total_inserts: AtomicU64,
    pub total_insert_time_us: AtomicU64,
    // Delete volume and cumulative latency.
    pub total_deletes: AtomicU64,
    pub total_delete_time_us: AtomicU64,
    // Update volume and cumulative latency.
    pub total_updates: AtomicU64,
    pub total_update_time_us: AtomicU64,
    // Index builds and cumulative build time.
    pub indexes_built: AtomicU64,
    pub total_build_time_us: AtomicU64,
    // Query-cache hit/miss counts.
    pub cache_hits: AtomicU64,
    pub cache_misses: AtomicU64,
}
impl EngineStats {
    /// All counters start at zero.
    pub fn new() -> Self {
        let zero = || AtomicU64::new(0);
        Self {
            total_searches: zero(),
            total_search_time_us: zero(),
            total_inserts: zero(),
            total_insert_time_us: zero(),
            total_deletes: zero(),
            total_delete_time_us: zero(),
            total_updates: zero(),
            total_update_time_us: zero(),
            indexes_built: zero(),
            total_build_time_us: zero(),
            cache_hits: zero(),
            cache_misses: zero(),
        }
    }

    /// Reads every counter into a plain snapshot, deriving the average
    /// search latency and cache hit rate (both 0 when no data yet).
    pub fn snapshot(&self) -> EngineStatsSnapshot {
        let load = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        let searches = load(&self.total_searches);
        let search_time = load(&self.total_search_time_us);
        let cache_hits = load(&self.cache_hits);
        let cache_misses = load(&self.cache_misses);
        let lookups = cache_hits + cache_misses;
        // checked_div returns None on division by zero.
        let avg_search_time_us = search_time.checked_div(searches).unwrap_or(0);
        let cache_hit_rate = if lookups == 0 {
            0.0
        } else {
            cache_hits as f64 / lookups as f64
        };
        EngineStatsSnapshot {
            total_searches: searches,
            total_search_time_us: search_time,
            avg_search_time_us,
            total_inserts: load(&self.total_inserts),
            total_insert_time_us: load(&self.total_insert_time_us),
            total_deletes: load(&self.total_deletes),
            total_updates: load(&self.total_updates),
            indexes_built: load(&self.indexes_built),
            cache_hits,
            cache_misses,
            cache_hit_rate,
        }
    }
}
impl Default for EngineStats {
    /// Equivalent to [`EngineStats::new`]: all counters zeroed.
    fn default() -> Self {
        Self::new()
    }
}
/// Plain, serializable point-in-time copy of [`EngineStats`], with the
/// average search latency and cache hit rate precomputed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EngineStatsSnapshot {
    pub total_searches: u64,
    pub total_search_time_us: u64,
    /// total_search_time_us / total_searches (0 when no searches yet).
    pub avg_search_time_us: u64,
    pub total_inserts: u64,
    pub total_insert_time_us: u64,
    pub total_deletes: u64,
    pub total_updates: u64,
    pub indexes_built: u64,
    pub cache_hits: u64,
    pub cache_misses: u64,
    /// hits / (hits + misses) in [0, 1]; 0.0 when no lookups yet.
    pub cache_hit_rate: f64,
}
/// A background worker that owns one engine instance and services the
/// shared-memory work queue.
pub struct EngineWorker {
    // Worker id used for logging and lifecycle status updates.
    id: u64,
    engine: RuVectorEngine,
}
impl EngineWorker {
    /// Creates a worker with its own engine, configured from the current
    /// global engine configuration.
    pub fn new(id: u64) -> Self {
        let config = get_engine_config();
        Self {
            id,
            engine: RuVectorEngine::new(config),
        }
    }

    /// Main worker loop: loads indexes, drains the shared work queue until
    /// shutdown is requested (or the engine is stopped), then persists
    /// state before exiting.
    pub fn run(&mut self) {
        pgrx::log!("Engine worker {} starting", self.id);
        self.engine.start();
        if let Err(e) = self.engine.load_from_storage() {
            // Best-effort: the worker keeps running without preloaded indexes.
            pgrx::warning!("Failed to load indexes: {}", e);
        }
        while self.engine.is_running() {
            if get_lifecycle_manager().is_shutdown_requested() {
                break;
            }
            self.process_work_queue();
            // Short sleep to avoid a hot spin when the queue is empty.
            std::thread::sleep(Duration::from_millis(1));
        }
        self.engine.persist_to_storage();
        self.engine.stop();
        pgrx::log!("Engine worker {} stopped", self.id);
    }

    /// Drains every currently queued work item, pushing one WorkResult per
    /// item. Cancelled items are skipped without a result; items whose
    /// deadline has already passed get a Timeout result without executing.
    fn process_work_queue(&mut self) {
        let shmem = get_shared_memory();
        while let Some(work_item) = shmem.work_queue.try_pop() {
            if shmem.is_cancelled(work_item.request_id) {
                continue;
            }
            // deadline_ms == 0 means "no deadline".
            if work_item.deadline_ms > 0 && current_epoch_ms() > work_item.deadline_ms {
                let result = WorkResult {
                    request_id: work_item.request_id,
                    status: ResultStatus::Timeout,
                    data: Vec::new(),
                    processing_time_us: 0,
                };
                shmem.result_queue.push(result);
                continue;
            }
            let start = Instant::now();
            let result = self.process_operation(&work_item);
            let processing_time_us = start.elapsed().as_micros() as u64;
            let work_result = match result {
                Ok(data) => {
                    shmem
                        .stats
                        .record_success(processing_time_us, data.len() as u64);
                    WorkResult {
                        request_id: work_item.request_id,
                        status: ResultStatus::Success,
                        data,
                        processing_time_us,
                    }
                }
                Err(e) => {
                    shmem.stats.record_failure();
                    WorkResult {
                        request_id: work_item.request_id,
                        status: ResultStatus::Error,
                        // The error text travels back as the result payload.
                        data: e.into_bytes(),
                        processing_time_us,
                    }
                }
            };
            shmem.result_queue.push(work_result);
        }
    }

    /// Dispatches a single operation to the engine, returning the
    /// serialized response payload (empty for write operations).
    /// LargePayloadRef items are decoded from the shared large-payload
    /// segment and processed recursively as if they had arrived inline.
    fn process_operation(&mut self, work_item: &WorkItem) -> Result<Vec<u8>, String> {
        match &work_item.operation {
            Operation::Search(req) => {
                let result = self.engine.search(req)?;
                Ok(result.to_bytes())
            }
            Operation::Insert(req) => {
                self.engine.insert(req)?;
                Ok(Vec::new())
            }
            Operation::Delete(req) => {
                self.engine.delete(req)?;
                Ok(Vec::new())
            }
            Operation::BuildIndex(req) => {
                self.engine.build_index(req)?;
                Ok(Vec::new())
            }
            Operation::UpdateIndex(req) => {
                self.engine.update_index(req)?;
                Ok(Vec::new())
            }
            Operation::LargePayloadRef(payload_ref) => {
                // Operation was too large for the queue: read its encoding
                // from shared memory, then recurse with the inline form.
                let shmem = get_shared_memory();
                let data = shmem
                    .large_payload_segment
                    .read(payload_ref.offset as usize, payload_ref.length as usize)?;
                let operation: Operation = bincode::deserialize(&data)
                    .map_err(|e| format!("Failed to decode operation: {}", e))?;
                let decoded_item = WorkItem {
                    operation,
                    ..work_item.clone()
                };
                self.process_operation(&decoded_item)
            }
            Operation::Ping => Ok(b"pong".to_vec()),
        }
    }
}
/// Postgres background-worker entry point: the Datum payload carries the
/// worker's numeric id. Runs the worker to completion, updating lifecycle
/// status around it.
#[pg_guard]
pub extern "C" fn ruvector_engine_worker_main(arg: pg_sys::Datum) {
    let id = arg.value() as u64;
    pgrx::log!("RuVector engine worker {} starting", id);
    let mut engine_worker = EngineWorker::new(id);
    get_lifecycle_manager().update_status(id, WorkerStatus::Running);
    engine_worker.run();
    get_lifecycle_manager().update_status(id, WorkerStatus::Stopped);
    pgrx::log!("RuVector engine worker {} stopped", id);
}
/// Seconds since the Unix epoch. Panics if the system clock reports a time
/// before 1970 (matches the original's unwrap behavior).
fn current_epoch_secs() -> u64 {
    let now = SystemTime::now();
    now.duration_since(UNIX_EPOCH).unwrap().as_secs()
}
/// Milliseconds since the Unix epoch, truncated to u64. Panics if the
/// system clock reports a time before 1970 (matches the original).
fn current_epoch_ms() -> u64 {
    let elapsed = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    elapsed.as_millis() as u64
}
/// Derives a 64-bit cache key by hashing every request field that affects
/// search results. f32 query components are hashed via their bit patterns
/// (`to_bits`), which is total and stable even for NaN, unlike hashing the
/// floats directly (f32 does not implement Hash).
fn compute_cache_key(request: &SearchRequest) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    let mut state = DefaultHasher::new();
    request.collection_id.hash(&mut state);
    request.k.hash(&mut state);
    request.ef_search.hash(&mut state);
    request.use_gnn.hash(&mut state);
    request.filter.hash(&mut state);
    request
        .query
        .iter()
        .for_each(|v| v.to_bits().hash(&mut state));
    state.finish()
}
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults must match the documented baseline configuration.
    #[test]
    fn test_engine_config_default() {
        let config = EngineWorkerConfig::default();
        assert_eq!(config.max_concurrent_searches, 64);
        assert!(config.enable_simd);
        assert!(config.enable_query_cache);
    }

    // Round-trips a SearchResult through to_bytes/from_bytes.
    #[test]
    fn test_search_result_serialization() {
        let result = SearchResult {
            ids: vec![1, 2, 3],
            distances: vec![0.1, 0.2, 0.3],
            search_time_us: 100,
            vectors_scanned: 1000,
            cache_hit: false,
        };
        let bytes = result.to_bytes();
        let decoded = SearchResult::from_bytes(&bytes).unwrap();
        assert_eq!(decoded.ids, result.ids);
        assert_eq!(decoded.distances, result.distances);
    }

    // A fresh index starts unloaded; record_query increments the counter.
    #[test]
    fn test_collection_index() {
        let index = CollectionIndex::new(1, "hnsw", 128);
        assert_eq!(index.collection_id, 1);
        assert_eq!(index.index_type, "hnsw");
        assert!(!index.is_loaded());
        index.record_query();
        assert_eq!(index.get_query_count(), 1);
    }

    // Basic put/get hit and miss behavior of the query cache.
    #[test]
    fn test_query_cache() {
        let cache = QueryCache::new(10, 60);
        let result = SearchResult {
            ids: vec![1],
            distances: vec![0.1],
            search_time_us: 100,
            vectors_scanned: 10,
            cache_hit: false,
        };
        cache.put(123, result.clone());
        let cached = cache.get(123).unwrap();
        assert_eq!(cached.ids, result.ids);
        assert!(cache.get(456).is_none());
    }

    // start()/stop() toggle the engine's running flag.
    #[test]
    fn test_engine_basic() {
        let config = EngineWorkerConfig::default();
        let mut engine = RuVectorEngine::new(config);
        engine.start();
        assert!(engine.is_running());
        engine.stop();
        assert!(!engine.is_running());
    }

    // Identical requests hash identically; a different collection_id
    // changes the key.
    #[test]
    fn test_cache_key_computation() {
        let req1 = SearchRequest {
            collection_id: 1,
            query: vec![1.0, 2.0, 3.0],
            k: 10,
            ef_search: Some(50),
            filter: None,
            use_gnn: false,
        };
        let req2 = SearchRequest {
            collection_id: 1,
            query: vec![1.0, 2.0, 3.0],
            k: 10,
            ef_search: Some(50),
            filter: None,
            use_gnn: false,
        };
        let req3 = SearchRequest {
            collection_id: 2,
            query: vec![1.0, 2.0, 3.0],
            k: 10,
            ef_search: Some(50),
            filter: None,
            use_gnn: false,
        };
        assert_eq!(compute_cache_key(&req1), compute_cache_key(&req2));
        assert_ne!(compute_cache_key(&req1), compute_cache_key(&req3));
    }
}