use parking_lot::RwLock;
use pgrx::prelude::*;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::OnceLock;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
/// Largest request payload (bytes) carried inline inside a `WorkItem`;
/// anything bigger is spilled to the large-payload segment.
pub const MAX_INLINE_SIZE: usize = 64 * 1024;
/// Total capacity (bytes) of the shared large-payload segment.
pub const MAX_LARGE_PAYLOAD_SIZE: usize = 16 * 1024 * 1024;
/// Number of slots in the work-queue ring.
pub const QUEUE_SIZE: usize = 1024;
/// Upper bound applied to caller-supplied request timeouts.
pub const MAX_REQUEST_TIMEOUT_MS: u64 = 30_000;
/// Enqueue attempts (with backoff) before giving up with `QueueFull`.
pub const MAX_SUBMIT_RETRIES: u32 = 10;
/// Size of the per-collection state tables in shared memory.
pub const MAX_COLLECTIONS: usize = 256;
/// Layout version, checked at init to detect mismatched binaries.
pub const SHMEM_VERSION: u32 = 1;
/// Location of a spilled payload inside the large-payload segment.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct PayloadRef {
    /// Byte offset from the start of the segment (slot-aligned).
    pub offset: u32,
    /// Payload length in bytes.
    pub length: u32,
}
/// Request dispatched from a Postgres backend to the engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Operation {
    /// k-NN search over a collection.
    Search(SearchRequest),
    /// Batch insert of vectors.
    Insert(InsertRequest),
    /// Delete vectors by id.
    Delete(DeleteRequest),
    /// Build a fresh index for a collection.
    BuildIndex(BuildIndexRequest),
    /// Incrementally update an existing index.
    UpdateIndex(UpdateIndexRequest),
    /// Indirection: the real serialized operation lives in the
    /// large-payload segment at this location.
    LargePayloadRef(PayloadRef),
    /// Liveness probe; carries no payload.
    Ping,
}
/// Parameters for a vector similarity search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchRequest {
    pub collection_id: i32,
    /// Query vector; dimensionality presumably must match the collection —
    /// TODO(review): confirm where that is validated.
    pub query: Vec<f32>,
    /// Number of nearest neighbours requested.
    pub k: usize,
    /// Optional search-time breadth override (e.g. HNSW ef_search).
    pub ef_search: Option<usize>,
    /// Optional filter expression; format is defined by the engine.
    pub filter: Option<String>,
    /// Whether to use the GNN-assisted search path.
    pub use_gnn: bool,
}
/// Batch insert of vectors into a collection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InsertRequest {
    pub collection_id: i32,
    /// Vectors to insert, parallel to `ids`.
    pub vectors: Vec<Vec<f32>>,
    /// One id per vector.
    pub ids: Vec<i64>,
    /// Optional per-vector metadata, presumably parallel to `ids` when
    /// present — TODO(review): confirm length contract.
    pub metadata: Option<Vec<String>>,
}
/// Deletion of vectors from a collection by id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeleteRequest {
    pub collection_id: i32,
    /// Ids of the vectors to remove.
    pub ids: Vec<i64>,
}
/// Request to build a new index for a collection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuildIndexRequest {
    pub collection_id: i32,
    /// Index kind identifier; valid values are defined by the engine.
    pub index_type: String,
    /// Engine-defined parameter string (format not shown in this file).
    pub params: String,
}
/// Incremental index update with new/changed vectors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateIndexRequest {
    pub collection_id: i32,
    /// Vectors to upsert, parallel to `ids`.
    pub vectors: Vec<Vec<f32>>,
    pub ids: Vec<i64>,
}
/// A queued unit of work plus its scheduling metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkItem {
    /// Unique id used to match the eventual `WorkResult`.
    pub request_id: u64,
    pub operation: Operation,
    /// Scheduling priority; callers in this file always use 128.
    pub priority: u8,
    /// Absolute deadline, milliseconds since the Unix epoch.
    pub deadline_ms: u64,
    /// PID of the submitting Postgres backend.
    pub backend_pid: i32,
}
/// Engine response for a single `WorkItem`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkResult {
    /// Matches the originating `WorkItem::request_id`.
    pub request_id: u64,
    pub status: ResultStatus,
    /// Serialized result payload; format depends on the operation.
    pub data: Vec<u8>,
    /// Engine-side processing time in microseconds.
    pub processing_time_us: u64,
}
/// Outcome classification of a processed request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ResultStatus {
    Success,
    Error,
    Timeout,
    Cancelled,
    QueueFull,
}
/// Bounded multi-producer/multi-consumer ring buffer of `WorkItem`s.
///
/// `head` and `tail` are monotonically increasing logical indices (mapped
/// into the buffer with `% capacity`); `tail - head` is the current length.
pub struct WorkQueue {
    /// Next logical index to pop.
    head: AtomicU64,
    /// Next logical index to push.
    tail: AtomicU64,
    /// Ring storage; `None` marks an empty slot.
    buffer: RwLock<Vec<Option<WorkItem>>>,
    capacity: usize,
}
impl WorkQueue {
    /// Creates a queue with `capacity` slots, all initially empty.
    pub fn new(capacity: usize) -> Self {
        let buffer = (0..capacity).map(|_| None).collect();
        Self {
            head: AtomicU64::new(0),
            tail: AtomicU64::new(0),
            buffer: RwLock::new(buffer),
            capacity,
        }
    }

    /// Enqueues `item`, or returns `QueueError::Full` when the ring is at
    /// capacity.
    ///
    /// Fix: the previous implementation CAS-advanced `tail` *before* writing
    /// the slot. A concurrent `try_pop` could advance `head` past that
    /// reserved slot and `take()` it while still `None`, silently losing the
    /// item once the producer finally wrote it. Performing the index update
    /// and slot write together under the buffer write lock closes that
    /// window; the atomics are kept so `len()` remains lock-free.
    pub fn push(&self, item: WorkItem) -> Result<(), QueueError> {
        let mut buffer = self.buffer.write();
        let tail = self.tail.load(Ordering::Acquire);
        let head = self.head.load(Ordering::Acquire);
        if tail.wrapping_sub(head) >= self.capacity as u64 {
            return Err(QueueError::Full);
        }
        let slot = (tail % self.capacity as u64) as usize;
        buffer[slot] = Some(item);
        // Publish the new tail only after the slot is populated.
        self.tail.store(tail.wrapping_add(1), Ordering::Release);
        Ok(())
    }

    /// Dequeues the oldest item, or returns `None` when the queue is empty.
    pub fn try_pop(&self) -> Option<WorkItem> {
        let mut buffer = self.buffer.write();
        let head = self.head.load(Ordering::Acquire);
        let tail = self.tail.load(Ordering::Acquire);
        if head >= tail {
            return None;
        }
        let slot = (head % self.capacity as u64) as usize;
        // Guaranteed `Some`: `push` writes the slot before advancing `tail`
        // while holding this same lock.
        let item = buffer[slot].take();
        self.head.store(head.wrapping_add(1), Ordering::Release);
        item
    }

    /// Number of queued items; approximate under concurrent mutation.
    pub fn len(&self) -> usize {
        let tail = self.tail.load(Ordering::Relaxed);
        let head = self.head.load(Ordering::Relaxed);
        tail.wrapping_sub(head) as usize
    }

    /// True when no items are queued.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Completed results keyed by request id, plus the set of ids a backend is
/// still waiting on.
pub struct ResultQueue {
    /// Results that have arrived but not yet been claimed.
    results: RwLock<std::collections::HashMap<u64, WorkResult>>,
    /// Request ids submitted but not yet answered.
    pending: RwLock<std::collections::HashSet<u64>>,
}
impl ResultQueue {
    /// Creates an empty queue with no completed or pending entries.
    pub fn new() -> Self {
        Self {
            results: RwLock::new(std::collections::HashMap::new()),
            pending: RwLock::new(std::collections::HashSet::new()),
        }
    }

    /// Records a completed result and clears its pending marker.
    pub fn push(&self, result: WorkResult) {
        let id = result.request_id;
        // Hold both locks while mutating so a reader never sees the result
        // stored while the id still appears pending.
        let mut completed = self.results.write();
        let mut waiting = self.pending.write();
        completed.insert(id, result);
        waiting.remove(&id);
    }

    /// Removes and returns the result for `request_id`, if one has arrived.
    pub fn try_get(&self, request_id: u64) -> Option<WorkResult> {
        self.results.write().remove(&request_id)
    }

    /// Marks `request_id` as awaiting a result.
    pub fn mark_pending(&self, request_id: u64) {
        self.pending.write().insert(request_id);
    }

    /// Reports whether `request_id` is still awaiting a result.
    pub fn is_pending(&self, request_id: u64) -> bool {
        self.pending.read().contains(&request_id)
    }
}
impl Default for ResultQueue {
    /// Equivalent to `ResultQueue::new()`.
    fn default() -> Self {
        Self::new()
    }
}
/// Errors reported by `WorkQueue` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueueError {
    /// The ring is at capacity.
    Full,
    /// Nothing to pop; not currently returned by any method in this file
    /// (`try_pop` signals emptiness with `None`).
    Empty,
}
/// Allocation granularity of the large-payload segment.
const SLOT_SIZE: usize = 64 * 1024;
/// Total slots in the segment (256 with the current sizes).
const NUM_SLOTS: usize = MAX_LARGE_PAYLOAD_SIZE / SLOT_SIZE;
/// Slot-granular shared buffer for payloads too large to send inline.
pub struct LargePayloadSegment {
    /// Total segment size in bytes.
    size: usize,
    /// One bit per slot; a set bit means the slot is allocated.
    alloc_bitmap: Vec<AtomicU64>,
    /// Backing storage for payload bytes.
    data: RwLock<Vec<u8>>,
}
impl LargePayloadSegment {
    /// Creates a fully-free segment of `MAX_LARGE_PAYLOAD_SIZE` bytes.
    pub fn new() -> Self {
        // One bit per slot, packed into 64-bit words (rounded up).
        let bitmap_size = (NUM_SLOTS + 63) / 64;
        let alloc_bitmap = (0..bitmap_size).map(|_| AtomicU64::new(0)).collect();
        Self {
            size: MAX_LARGE_PAYLOAD_SIZE,
            alloc_bitmap,
            data: RwLock::new(vec![0u8; MAX_LARGE_PAYLOAD_SIZE]),
        }
    }

    /// Reserves a contiguous run of slots big enough for `size` bytes.
    ///
    /// Returns `None` when `size` exceeds the segment or no contiguous free
    /// run exists. Fix: a zero-byte request now reserves one slot — it
    /// previously returned `offset: 0` without marking anything allocated,
    /// so the "reserved" range aliased whatever live allocation held slot 0.
    pub fn allocate(&self, size: usize) -> Option<PayloadRef> {
        if size > MAX_LARGE_PAYLOAD_SIZE {
            return None;
        }
        let slots_needed = size.div_ceil(SLOT_SIZE).max(1);
        for start_slot in 0..=(NUM_SLOTS - slots_needed) {
            if self.try_allocate_range(start_slot, slots_needed) {
                return Some(PayloadRef {
                    offset: (start_slot * SLOT_SIZE) as u32,
                    length: size as u32,
                });
            }
        }
        None
    }

    /// Atomically claims slots `start..start + count`, rolling back on
    /// conflict. Returns `false` if any slot was (or became) taken.
    fn try_allocate_range(&self, start: usize, count: usize) -> bool {
        // Fast pre-scan: bail out early if any slot already looks taken.
        for slot in start..(start + count) {
            let word = slot / 64;
            let bit = slot % 64;
            let bitmap = self.alloc_bitmap[word].load(Ordering::Acquire);
            if bitmap & (1u64 << bit) != 0 {
                return false;
            }
        }
        // Claim each bit with CAS; a concurrent allocator may win a bit
        // after the pre-scan, in which case everything claimed so far in
        // this call is released before reporting failure.
        for slot in start..(start + count) {
            let word = slot / 64;
            let bit = slot % 64;
            let mask = 1u64 << bit;
            loop {
                let current = self.alloc_bitmap[word].load(Ordering::Acquire);
                if current & mask != 0 {
                    for s in start..slot {
                        let w = s / 64;
                        let b = s % 64;
                        self.alloc_bitmap[w].fetch_and(!(1u64 << b), Ordering::Release);
                    }
                    return false;
                }
                if self.alloc_bitmap[word]
                    .compare_exchange_weak(
                        current,
                        current | mask,
                        Ordering::AcqRel,
                        Ordering::Relaxed,
                    )
                    .is_ok()
                {
                    break;
                }
            }
        }
        true
    }

    /// Releases the slots covered by `payload_ref`.
    ///
    /// Mirrors `allocate`: a zero-length ref still frees the single slot it
    /// reserved.
    pub fn free(&self, payload_ref: &PayloadRef) {
        let start_slot = payload_ref.offset as usize / SLOT_SIZE;
        let slots = (payload_ref.length as usize).div_ceil(SLOT_SIZE).max(1);
        for slot in start_slot..(start_slot + slots) {
            let word = slot / 64;
            let bit = slot % 64;
            self.alloc_bitmap[word].fetch_and(!(1u64 << bit), Ordering::Release);
        }
    }

    /// Copies `data` into the segment at `offset`.
    ///
    /// Fix: the end-of-range is computed with `checked_add` so a huge
    /// `offset` cannot wrap around the bounds check and panic on the slice.
    pub fn write(&self, offset: usize, data: &[u8]) -> Result<(), String> {
        let end = offset
            .checked_add(data.len())
            .filter(|&e| e <= self.size)
            .ok_or_else(|| "Write exceeds segment bounds".to_string())?;
        let mut buffer = self.data.write();
        buffer[offset..end].copy_from_slice(data);
        Ok(())
    }

    /// Returns a copy of `length` bytes starting at `offset`.
    pub fn read(&self, offset: usize, length: usize) -> Result<Vec<u8>, String> {
        let end = offset
            .checked_add(length)
            .filter(|&e| e <= self.size)
            .ok_or_else(|| "Read exceeds segment bounds".to_string())?;
        let buffer = self.data.read();
        Ok(buffer[offset..end].to_vec())
    }

    /// Bytes currently reserved (allocated slots × slot size).
    pub fn bytes_used(&self) -> usize {
        let mut count = 0;
        for bitmap in &self.alloc_bitmap {
            count += bitmap.load(Ordering::Relaxed).count_ones() as usize;
        }
        count * SLOT_SIZE
    }
}
impl Default for LargePayloadSegment {
    /// Equivalent to `LargePayloadSegment::new()`.
    fn default() -> Self {
        Self::new()
    }
}
/// Cached status of one collection's index, mirrored in shared memory.
#[derive(Debug, Clone, Default)]
pub struct IndexState {
    pub collection_id: i32,
    /// Whether the index is resident in the engine.
    pub loaded: bool,
    pub vector_count: u64,
    pub size_bytes: u64,
    /// Timestamp of the most recent query — units not shown in this file;
    /// TODO(review): confirm epoch milliseconds.
    pub last_query_at: u64,
    pub query_count: u64,
}
/// Per-collection access-control snapshot pushed from the integrity
/// subsystem into shared memory.
///
/// Fix: `state` and `lambda_cut` were jammed onto a single line; reformatted
/// to one declaration per line per rustfmt convention.
#[derive(Debug, Clone, Default)]
pub struct IntegrityPermissions {
    /// Collection this entry applies to.
    pub collection_id: i32,
    /// Opaque integrity state code; meaning defined by the producer —
    /// TODO(review): confirm the state enumeration.
    pub state: u8,
    /// Integrity threshold; semantics defined by the producer — TODO confirm.
    pub lambda_cut: f64,
    /// Whether read operations are currently permitted.
    pub allow_reads: bool,
    /// Whether write operations are currently permitted.
    pub allow_writes: bool,
    /// Whether delete operations are currently permitted.
    pub allow_deletes: bool,
    /// Last update timestamp — units not shown here; TODO confirm epoch ms.
    pub last_update: u64,
}
/// Process-wide IPC counters; every field is a monotone counter updated with
/// relaxed atomics (no cross-field invariants are maintained).
#[derive(Debug, Default)]
pub struct GlobalStats {
    pub total_requests: AtomicU64,
    pub successful_requests: AtomicU64,
    pub failed_requests: AtomicU64,
    pub timeouts: AtomicU64,
    pub queue_full_events: AtomicU64,
    pub bytes_processed: AtomicU64,
    pub total_processing_time_us: AtomicU64,
}
impl GlobalStats {
pub fn new() -> Self {
Self::default()
}
pub fn record_success(&self, processing_time_us: u64, bytes: u64) {
self.total_requests.fetch_add(1, Ordering::Relaxed);
self.successful_requests.fetch_add(1, Ordering::Relaxed);
self.total_processing_time_us
.fetch_add(processing_time_us, Ordering::Relaxed);
self.bytes_processed.fetch_add(bytes, Ordering::Relaxed);
}
pub fn record_failure(&self) {
self.total_requests.fetch_add(1, Ordering::Relaxed);
self.failed_requests.fetch_add(1, Ordering::Relaxed);
}
pub fn record_timeout(&self) {
self.total_requests.fetch_add(1, Ordering::Relaxed);
self.timeouts.fetch_add(1, Ordering::Relaxed);
}
pub fn record_queue_full(&self) {
self.queue_full_events.fetch_add(1, Ordering::Relaxed);
}
pub fn to_json(&self) -> serde_json::Value {
let total = self.total_requests.load(Ordering::Relaxed);
let successful = self.successful_requests.load(Ordering::Relaxed);
let total_time = self.total_processing_time_us.load(Ordering::Relaxed);
serde_json::json!({
"total_requests": total,
"successful_requests": successful,
"failed_requests": self.failed_requests.load(Ordering::Relaxed),
"timeouts": self.timeouts.load(Ordering::Relaxed),
"queue_full_events": self.queue_full_events.load(Ordering::Relaxed),
"bytes_processed": self.bytes_processed.load(Ordering::Relaxed),
"total_processing_time_us": total_time,
"avg_processing_time_us": if successful > 0 { total_time / successful } else { 0 },
})
}
}
/// Top-level layout of everything shared between backends and the engine.
pub struct SharedMemoryLayout {
    /// Layout version; compared against `SHMEM_VERSION` at init.
    pub version: AtomicU32,
    /// Initialization flag/lock word; not manipulated in this file.
    pub init_lock: AtomicU32,
    pub work_queue: WorkQueue,
    pub result_queue: ResultQueue,
    pub large_payload_segment: LargePayloadSegment,
    /// Per-collection index status, indexed by collection id.
    pub index_states: RwLock<Vec<IndexState>>,
    /// Per-collection access permissions, indexed by collection id.
    pub integrity_states: RwLock<Vec<IntegrityPermissions>>,
    pub stats: GlobalStats,
    /// Source of unique request ids (starts at 1).
    next_request_id: AtomicU64,
    /// Ids of requests cancelled by their backend.
    cancelled: RwLock<std::collections::HashSet<u64>>,
}
impl SharedMemoryLayout {
    /// Builds a fresh layout: empty queues, a zeroed payload segment, and
    /// default-initialized per-collection state tables.
    pub fn new() -> Self {
        Self {
            version: AtomicU32::new(SHMEM_VERSION),
            init_lock: AtomicU32::new(0),
            work_queue: WorkQueue::new(QUEUE_SIZE),
            result_queue: ResultQueue::new(),
            large_payload_segment: LargePayloadSegment::new(),
            index_states: RwLock::new(vec![IndexState::default(); MAX_COLLECTIONS]),
            integrity_states: RwLock::new(vec![IntegrityPermissions::default(); MAX_COLLECTIONS]),
            stats: GlobalStats::new(),
            next_request_id: AtomicU64::new(1),
            cancelled: RwLock::new(std::collections::HashSet::new()),
        }
    }

    /// Returns a unique, monotonically increasing request id.
    pub fn next_request_id(&self) -> u64 {
        self.next_request_id.fetch_add(1, Ordering::SeqCst)
    }

    /// Marks `request_id` as cancelled so the engine can skip or abort it.
    pub fn cancel_request(&self, request_id: u64) {
        self.cancelled.write().insert(request_id);
    }

    /// True if `request_id` has been cancelled.
    pub fn is_cancelled(&self, request_id: u64) -> bool {
        self.cancelled.read().contains(&request_id)
    }

    /// Bounds the cancelled-id set.
    ///
    /// `max_age_ms` is currently unused (it was silently ignored before):
    /// entries carry no timestamp, so age-based eviction is impossible and
    /// the whole set is instead dropped once it grows past a size threshold.
    /// TODO(review): store (id, cancelled_at) pairs if precise aging is
    /// required.
    pub fn cleanup_cancelled(&self, max_age_ms: u64) {
        let _ = max_age_ms; // no timestamps to age against — see doc comment
        let mut cancelled = self.cancelled.write();
        if cancelled.len() > 10000 {
            cancelled.clear();
        }
    }

    /// Wakes the background engine. Currently a no-op placeholder: nothing
    /// in this file implements an actual signaling mechanism.
    pub fn signal_engine(&self) {
    }

    /// Stores `permissions` for `collection_id`. Out-of-range ids — including
    /// negative ones — are silently ignored, matching the original wrapping
    /// cast + bound check.
    pub fn update_integrity_permissions(
        &self,
        collection_id: i32,
        permissions: &IntegrityPermissions,
    ) {
        if let Ok(idx) = usize::try_from(collection_id) {
            if idx < MAX_COLLECTIONS {
                self.integrity_states.write()[idx] = permissions.clone();
            }
        }
    }

    /// Returns a copy of the permissions for `collection_id`, or `None`
    /// when the id is out of range.
    pub fn get_integrity_permissions(&self, collection_id: i32) -> Option<IntegrityPermissions> {
        let idx = usize::try_from(collection_id).ok()?;
        if idx < MAX_COLLECTIONS {
            Some(self.integrity_states.read()[idx].clone())
        } else {
            None
        }
    }
}
impl Default for SharedMemoryLayout {
    /// Equivalent to `SharedMemoryLayout::new()`.
    fn default() -> Self {
        Self::new()
    }
}
/// Lazily-initialized global layout instance.
///
/// NOTE(review): `OnceLock` is per-process, so each Postgres backend gets
/// its own copy — this stands in for a real shared-memory mapping. Verify
/// this is only relied on in single-process/test configurations.
static SHARED_MEMORY: OnceLock<SharedMemoryLayout> = OnceLock::new();
/// Returns the global layout, initializing it on first use.
pub fn get_shared_memory() -> &'static SharedMemoryLayout {
    SHARED_MEMORY.get_or_init(SharedMemoryLayout::new)
}
/// Thin namespace wrapper around the global layout accessors.
pub struct SharedMemory;
impl SharedMemory {
    /// Returns the global layout (alias for `get_shared_memory`).
    pub fn get() -> &'static SharedMemoryLayout {
        get_shared_memory()
    }
    /// Infallible here; kept as `Result` for interface compatibility with a
    /// real shared-memory attach that could fail.
    pub fn attach() -> Result<&'static SharedMemoryLayout, String> {
        Ok(get_shared_memory())
    }
}
/// Validates the layout version and logs successful initialization.
///
/// Errors when the stored version differs from `SHMEM_VERSION` (e.g. a
/// binary upgrade against a stale layout).
pub fn init_shared_memory() -> Result<(), String> {
    let shmem = get_shared_memory();
    if shmem.version.load(Ordering::SeqCst) != SHMEM_VERSION {
        return Err("Shared memory version mismatch".to_string());
    }
    pgrx::log!("Shared memory initialized (version {})", SHMEM_VERSION);
    Ok(())
}
/// Milliseconds since the Unix epoch.
///
/// Fix: the previous `unwrap()` would panic if the system clock reported a
/// time before the epoch (possible under clock skew). Returning 0 instead is
/// safe for this file's only use — deadline computation — where 0 simply
/// means "already expired".
fn current_epoch_ms() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_millis() as u64)
        .unwrap_or(0)
}
pub fn submit_and_wait(operation: Operation, timeout_ms: u64) -> Result<WorkResult, IpcError> {
let shmem = get_shared_memory();
let request_id = shmem.next_request_id();
let (final_operation, payload_ref) = prepare_operation(operation, shmem)?;
let work_item = WorkItem {
request_id,
operation: final_operation,
priority: 128, deadline_ms: current_epoch_ms() + timeout_ms.min(MAX_REQUEST_TIMEOUT_MS),
backend_pid: unsafe { pg_sys::MyProcPid },
};
let mut retry_count = 0;
loop {
match shmem.work_queue.push(work_item.clone()) {
Ok(()) => break,
Err(QueueError::Full) => {
retry_count += 1;
if retry_count > MAX_SUBMIT_RETRIES {
if let Some(ref pr) = payload_ref {
shmem.large_payload_segment.free(pr);
}
shmem.stats.record_queue_full();
return Err(IpcError::QueueFull);
}
std::thread::sleep(Duration::from_millis(1 << retry_count.min(6)));
}
Err(QueueError::Empty) => unreachable!(),
}
}
shmem.result_queue.mark_pending(request_id);
shmem.signal_engine();
let deadline = Instant::now() + Duration::from_millis(timeout_ms);
loop {
if let Some(result) = shmem.result_queue.try_get(request_id) {
if let Some(ref pr) = payload_ref {
shmem.large_payload_segment.free(pr);
}
return Ok(result);
}
if Instant::now() > deadline {
shmem.cancel_request(request_id);
if let Some(ref pr) = payload_ref {
shmem.large_payload_segment.free(pr);
}
shmem.stats.record_timeout();
return Err(IpcError::Timeout);
}
if unsafe { pg_sys::QueryCancelPending != 0 } {
shmem.cancel_request(request_id);
if let Some(ref pr) = payload_ref {
shmem.large_payload_segment.free(pr);
}
return Err(IpcError::Cancelled);
}
std::thread::sleep(Duration::from_millis(1));
}
}
/// Serializes `operation`; small payloads are sent inline, while anything
/// larger than `MAX_INLINE_SIZE` is spilled into the shared payload segment
/// and replaced by an `Operation::LargePayloadRef`.
///
/// On spill, the reserved `PayloadRef` is also returned so the caller can
/// free it once the request completes.
fn prepare_operation(
    operation: Operation,
    shmem: &SharedMemoryLayout,
) -> Result<(Operation, Option<PayloadRef>), IpcError> {
    let encoded = match bincode::serialize(&operation) {
        Ok(bytes) => bytes,
        Err(e) => return Err(IpcError::SerializationError(e.to_string())),
    };
    if encoded.len() <= MAX_INLINE_SIZE {
        Ok((operation, None))
    } else {
        let segment = &shmem.large_payload_segment;
        let payload_ref = match segment.allocate(encoded.len()) {
            Some(r) => r,
            None => return Err(IpcError::PayloadTooLarge),
        };
        if let Err(msg) = segment.write(payload_ref.offset as usize, &encoded) {
            return Err(IpcError::SharedMemoryError(msg));
        }
        Ok((Operation::LargePayloadRef(payload_ref), Some(payload_ref)))
    }
}
/// Errors surfaced to backends by the IPC layer.
#[derive(Debug, Clone)]
pub enum IpcError {
    /// The work queue stayed full through all submit retries.
    QueueFull,
    /// No result arrived before the deadline.
    Timeout,
    /// The backend's query was cancelled while waiting.
    Cancelled,
    /// The payload did not fit in the large-payload segment.
    PayloadTooLarge,
    /// A segment read/write failed; carries the underlying message.
    SharedMemoryError(String),
    /// Operation serialization failed; carries the encoder's message.
    SerializationError(String),
}
impl std::fmt::Display for IpcError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
IpcError::QueueFull => write!(f, "Work queue is full"),
IpcError::Timeout => write!(f, "Operation timed out"),
IpcError::Cancelled => write!(f, "Operation was cancelled"),
IpcError::PayloadTooLarge => write!(f, "Payload too large for shared segment"),
IpcError::SharedMemoryError(e) => write!(f, "Shared memory error: {}", e),
IpcError::SerializationError(e) => write!(f, "Serialization error: {}", e),
}
}
}
impl std::error::Error for IpcError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Push/pop round-trip and length accounting on the ring buffer.
    #[test]
    fn test_work_queue_basic() {
        let queue = WorkQueue::new(16);
        let item = WorkItem {
            request_id: 1,
            operation: Operation::Ping,
            priority: 128,
            deadline_ms: 0,
            backend_pid: 0,
        };
        assert!(queue.push(item.clone()).is_ok());
        assert_eq!(queue.len(), 1);
        let popped = queue.try_pop().unwrap();
        assert_eq!(popped.request_id, 1);
        assert!(queue.is_empty());
    }

    /// A queue at capacity rejects further pushes with `Full`.
    #[test]
    fn test_work_queue_full() {
        let queue = WorkQueue::new(2);
        let item = WorkItem {
            request_id: 1,
            operation: Operation::Ping,
            priority: 128,
            deadline_ms: 0,
            backend_pid: 0,
        };
        assert!(queue.push(item.clone()).is_ok());
        assert!(queue.push(item.clone()).is_ok());
        assert_eq!(queue.push(item.clone()), Err(QueueError::Full));
    }

    /// Pending marker is cleared when the result arrives and the result can
    /// be claimed exactly once.
    #[test]
    fn test_result_queue() {
        let queue = ResultQueue::new();
        queue.mark_pending(1);
        assert!(queue.is_pending(1));
        let result = WorkResult {
            request_id: 1,
            status: ResultStatus::Success,
            data: vec![],
            processing_time_us: 100,
        };
        queue.push(result);
        assert!(!queue.is_pending(1));
        let retrieved = queue.try_get(1).unwrap();
        assert_eq!(retrieved.request_id, 1);
    }

    /// Allocate/write/read/free round-trip on the payload segment; a freed
    /// segment reports zero bytes used.
    #[test]
    fn test_large_payload_segment() {
        let segment = LargePayloadSegment::new();
        let payload_ref = segment.allocate(100 * 1024).unwrap();
        assert_eq!(payload_ref.offset, 0);
        assert_eq!(payload_ref.length, 100 * 1024);
        let data = vec![42u8; 1000];
        segment.write(0, &data).unwrap();
        let read_data = segment.read(0, 1000).unwrap();
        assert_eq!(data, read_data);
        segment.free(&payload_ref);
        assert_eq!(segment.bytes_used(), 0);
    }

    /// Counter bookkeeping: successes, failures and timeouts all increment
    /// `total_requests`.
    #[test]
    fn test_global_stats() {
        let stats = GlobalStats::new();
        stats.record_success(100, 1000);
        stats.record_success(200, 2000);
        stats.record_failure();
        stats.record_timeout();
        let json = stats.to_json();
        assert_eq!(json["total_requests"], 4);
        assert_eq!(json["successful_requests"], 2);
        assert_eq!(json["failed_requests"], 1);
        assert_eq!(json["timeouts"], 1);
    }
}