use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use std::cmp::Ordering as CmpOrdering;
use std::collections::BinaryHeap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::OnceLock;
use std::time::{SystemTime, UNIX_EPOCH};
/// Category of work carried by a [`Task`].
///
/// `TaskQueueRegistry::get_queue` uses the variant to pick which queue a task belongs to,
/// and the `Display` impl renders each variant as a lowercase snake_case label.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TaskType {
/// Routed to the `queries` queue; displayed as `query`.
Query,
/// Routed to the `queries` queue; displayed as `insert`.
Insert,
/// Routed to the `queries` queue; displayed as `delete`.
Delete,
/// Routed to the `maintenance` queue; displayed as `maintenance`.
Maintenance,
/// Routed to the `training` queue; displayed as `training`.
Training,
/// Routed to the `integrity` queue; displayed as `integrity`.
Integrity,
/// Routed to the `maintenance` queue; displayed as `index_build`.
IndexBuild,
/// Routed to the `maintenance` queue; displayed as `stats_collection`.
StatsCollection,
}
impl std::fmt::Display for TaskType {
    /// Writes the lowercase snake_case label of the variant (for example `index_build`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match *self {
            TaskType::Query => "query",
            TaskType::Insert => "insert",
            TaskType::Delete => "delete",
            TaskType::Maintenance => "maintenance",
            TaskType::Training => "training",
            TaskType::Integrity => "integrity",
            TaskType::IndexBuild => "index_build",
            TaskType::StatsCollection => "stats_collection",
        };
        f.write_str(label)
    }
}
/// Scheduling priority of a [`Task`].
///
/// A smaller discriminant means a more urgent task: `PrioritizedTask::new` multiplies the
/// discriminant by 100 to form the base of the effective priority, and the heap ordering
/// serves the smallest effective value first. The derived `Ord` follows declaration order,
/// so `Critical < High < Medium < Low`.
#[derive(
Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
pub enum TaskPriority {
/// Most urgent level (discriminant 0).
Critical = 0,
/// Discriminant 1.
High = 1,
/// Discriminant 2; this is the `Default` value.
#[default]
Medium = 2,
/// Least urgent level (discriminant 3).
Low = 3,
}
impl std::fmt::Display for TaskPriority {
    /// Writes the lowercase name of the priority level.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            TaskPriority::Critical => "critical",
            TaskPriority::High => "high",
            TaskPriority::Medium => "medium",
            TaskPriority::Low => "low",
        };
        f.write_str(label)
    }
}
/// A unit of work that can be placed on a `TaskQueue`.
///
/// All fields are public; `Task::new` fills in defaults and the `with_*` builder methods
/// override individual fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
/// Identifier assigned by `Task::new` from a process-wide counter starting at 1.
pub id: u64,
/// Category of the task.
pub task_type: TaskType,
/// Base scheduling priority; `Task::new` uses `TaskPriority::default()`.
pub priority: TaskPriority,
/// Payload bytes; this file stores them but never inspects them.
pub data: Vec<u8>,
/// Creation time in milliseconds since the Unix epoch, as read by `Task::new`.
pub created_at: u64,
/// Absolute deadline in milliseconds since the Unix epoch; `0` means "no deadline"
/// (see `Task::is_expired`).
pub deadline_ms: u64,
/// Upper bound for `retry_count`; `Task::new` sets it to 3.
pub max_retries: u32,
/// Number of retries recorded so far via `Task::increment_retry`.
pub retry_count: u32,
/// Optional collection identifier; stored only, not interpreted in this file.
pub collection_id: Option<i32>,
/// IDs of tasks that must be marked completed on the queue before this task is dequeued.
pub dependencies: Vec<u64>,
}
impl Task {
    /// Builds a task of `task_type` carrying `data`.
    ///
    /// The ID comes from a process-wide counter and `created_at` from the system clock.
    /// Everything else starts at its default: default priority, no deadline
    /// (`deadline_ms == 0`), three retries allowed, no collection, no dependencies.
    pub fn new(task_type: TaskType, data: Vec<u8>) -> Self {
        // Shared by every call so that each task gets a distinct ID.
        static NEXT_ID: AtomicU64 = AtomicU64::new(1);

        let id = NEXT_ID.fetch_add(1, Ordering::SeqCst);
        let created_at = current_epoch_ms();
        Self {
            id,
            task_type,
            priority: TaskPriority::default(),
            data,
            created_at,
            deadline_ms: 0,
            max_retries: 3,
            retry_count: 0,
            collection_id: None,
            dependencies: Vec::new(),
        }
    }

    /// Returns the task with its priority replaced.
    pub fn with_priority(self, priority: TaskPriority) -> Self {
        Self { priority, ..self }
    }

    /// Returns the task with an absolute deadline in epoch milliseconds (`0` = none).
    pub fn with_deadline(self, deadline_ms: u64) -> Self {
        Self { deadline_ms, ..self }
    }

    /// Returns the task with a different retry limit.
    pub fn with_max_retries(self, max_retries: u32) -> Self {
        Self { max_retries, ..self }
    }

    /// Returns the task tagged with `collection_id`.
    pub fn with_collection(self, collection_id: i32) -> Self {
        Self {
            collection_id: Some(collection_id),
            ..self
        }
    }

    /// Returns the task with its dependency list replaced by `deps`.
    pub fn with_dependencies(self, deps: Vec<u64>) -> Self {
        Self {
            dependencies: deps,
            ..self
        }
    }

    /// Reports whether the deadline has passed.
    ///
    /// A task without a deadline (`deadline_ms == 0`) never expires; otherwise the task is
    /// expired once the current time is strictly greater than the deadline.
    pub fn is_expired(&self) -> bool {
        // Short-circuit keeps the clock untouched for deadline-free tasks.
        self.deadline_ms != 0 && current_epoch_ms() > self.deadline_ms
    }

    /// Reports whether another retry is still allowed.
    pub fn can_retry(&self) -> bool {
        self.max_retries > self.retry_count
    }

    /// Records one more retry attempt.
    pub fn increment_retry(&mut self) {
        self.retry_count += 1;
    }
}
/// Heap entry pairing a [`Task`] with the priority value computed when it was inserted.
#[derive(Debug)]
struct PrioritizedTask {
/// The wrapped task.
task: Task,
/// Value computed by `PrioritizedTask::new`; smaller values are served first because
/// the `Ord` impl compares in reverse.
effective_priority: i64,
}
impl PrioritizedTask {
    /// Wraps `task` and computes its effective priority (smaller = served earlier).
    ///
    /// The value is `priority discriminant * 100`, minus one point per full second since
    /// `created_at`, plus a deadline adjustment: `-100` when less than 1 s remains (this
    /// includes deadlines already in the past), `-50` below 5 s, `-10` below 30 s, and `0`
    /// otherwise or when the task has no deadline (`deadline_ms == 0`).
    fn new(task: Task) -> Self {
        // Read the clock once so age and remaining time refer to the same instant.
        let now = current_epoch_ms();

        // `created_at` may be ahead of the local clock (clock adjustment, or a task that
        // was deserialized from elsewhere). Treat that as age zero instead of underflowing.
        let age_secs = now.saturating_sub(task.created_at) / 1000;
        // Lossless: u64::MAX / 1000 is far below i64::MAX.
        let age_bonus = age_secs as i64;

        let deadline_urgency: i64 = if task.deadline_ms == 0 {
            0
        } else {
            match task.deadline_ms.saturating_sub(now) {
                0..=999 => -100,
                1000..=4999 => -50,
                5000..=29_999 => -10,
                _ => 0,
            }
        };

        let effective_priority = (task.priority as i64 * 100) - age_bonus + deadline_urgency;
        Self {
            task,
            effective_priority,
        }
    }
}
impl PartialEq for PrioritizedTask {
    /// Two entries are equal when their effective priorities match; the wrapped task is
    /// not compared, which keeps equality consistent with the `Ord` impl.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.effective_priority, other.effective_priority);
        lhs == rhs
    }
}
// Marker impl: equality on the `i64` effective priority is a full equivalence relation,
// and `Ord` (needed by `BinaryHeap`) requires `Eq`.
impl Eq for PrioritizedTask {}
impl PartialOrd for PrioritizedTask {
    /// Always `Some`: delegates to the total order defined by the `Ord` impl.
    fn partial_cmp(&self, other: &Self) -> Option<CmpOrdering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl Ord for PrioritizedTask {
    /// Reverse numeric order on the effective priority.
    ///
    /// `BinaryHeap` is a max-heap, so reversing makes the entry with the smallest
    /// effective priority compare greatest and therefore pop first.
    fn cmp(&self, other: &Self) -> CmpOrdering {
        self.effective_priority
            .cmp(&other.effective_priority)
            .reverse()
    }
}
/// Priority queue of tasks with optional per-worker queues and counters.
pub struct TaskQueue {
/// Shared heap of pending tasks, ordered by `PrioritizedTask`'s `Ord` impl.
queue: Mutex<BinaryHeap<PrioritizedTask>>,
/// One task list per registered worker; a worker's ID is its index in the outer `Vec`.
worker_queues: RwLock<Vec<Mutex<Vec<Task>>>>,
/// IDs passed to `mark_completed`; consulted when checking task dependencies.
completed: RwLock<std::collections::HashSet<u64>>,
/// IDs passed to `mark_failed`.
failed: RwLock<std::collections::HashSet<u64>>,
/// Atomic counters describing queue activity.
stats: QueueStats,
/// Capacity limit that `enqueue` compares against the length of the shared heap.
max_size: usize,
}
impl TaskQueue {
    /// Creates an empty queue whose shared heap accepts at most `max_size` tasks via
    /// [`enqueue`](Self::enqueue). The heap's storage is reserved up front.
    pub fn new(max_size: usize) -> Self {
        Self {
            queue: Mutex::new(BinaryHeap::with_capacity(max_size)),
            worker_queues: RwLock::new(Vec::new()),
            completed: RwLock::new(std::collections::HashSet::new()),
            failed: RwLock::new(std::collections::HashSet::new()),
            stats: QueueStats::new(),
            max_size,
        }
    }

    /// Registers a new, empty per-worker queue and returns its index, which is the
    /// worker ID expected by the other worker-related methods.
    pub fn add_worker(&self) -> usize {
        let mut queues = self.worker_queues.write();
        let worker_id = queues.len();
        queues.push(Mutex::new(Vec::new()));
        worker_id
    }

    /// Moves every task still parked in `worker_id`'s queue back to the shared heap.
    ///
    /// The worker's slot stays allocated (and empty) so that the IDs of the other
    /// workers remain valid. An unknown `worker_id` is ignored. Returned tasks are not
    /// subject to the `max_size` limit, because dropping them would lose work.
    pub fn remove_worker(&self, worker_id: usize) {
        // Take the tasks out of the worker queue first. Leaving them behind (as a
        // clone-only version would) lets the same task run twice.
        let orphaned: Vec<Task> = {
            let queues = self.worker_queues.read();
            let Some(worker_queue) = queues.get(worker_id) else {
                return;
            };
            let mut guard = worker_queue.lock();
            std::mem::take(&mut *guard)
        };
        if orphaned.is_empty() {
            return;
        }

        let mut main_queue = self.queue.lock();
        main_queue.extend(orphaned.into_iter().map(PrioritizedTask::new));
        self.record_depth(main_queue.len());
    }

    /// Adds `task` to the shared heap and returns its ID.
    ///
    /// Tasks whose dependencies are not yet completed are accepted as well; they are
    /// counted in the `pending` statistic and held back by [`dequeue`](Self::dequeue).
    ///
    /// # Errors
    ///
    /// Returns [`QueueError::QueueFull`] when the shared heap already holds `max_size`
    /// tasks.
    pub fn enqueue(&self, task: Task) -> Result<u64, QueueError> {
        // Evaluated before the heap is locked so the two locks are not nested here.
        let blocked = !self.dependencies_satisfied(&task);
        let task_id = task.id;

        // Capacity check and push happen under one lock acquisition, so concurrent
        // producers cannot overshoot `max_size`.
        let mut queue = self.queue.lock();
        if queue.len() >= self.max_size {
            self.stats.rejected.fetch_add(1, Ordering::Relaxed);
            return Err(QueueError::QueueFull);
        }
        queue.push(PrioritizedTask::new(task));

        if blocked {
            self.stats.pending.fetch_add(1, Ordering::Relaxed);
        }
        self.stats.enqueued.fetch_add(1, Ordering::Relaxed);
        // Uses the length we already have; calling `len()` here would try to lock the
        // heap a second time on the same thread and deadlock.
        self.record_depth(queue.len());
        Ok(task_id)
    }

    /// Removes and returns the most urgent task that is ready to run.
    ///
    /// Expired tasks encountered on the way are discarded and counted as `expired`.
    /// Tasks with unfinished dependencies stay in the heap. Returns `None` when no
    /// runnable task exists.
    pub fn dequeue(&self) -> Option<Task> {
        let mut queue = self.queue.lock();
        // Blocked tasks are parked here and re-inserted after the scan. Pushing them
        // straight back would make the next `pop` return the same entry forever.
        let mut blocked = Vec::new();
        let mut ready = None;

        while let Some(prioritized) = queue.pop() {
            if prioritized.task.is_expired() {
                self.stats.expired.fetch_add(1, Ordering::Relaxed);
                continue;
            }
            if !self.dependencies_satisfied(&prioritized.task) {
                blocked.push(prioritized);
                continue;
            }
            ready = Some(prioritized.task);
            break;
        }

        queue.extend(blocked);
        if ready.is_some() {
            self.stats.dequeued.fetch_add(1, Ordering::Relaxed);
        }
        // Recorded unconditionally because dropping expired tasks also changes the depth.
        self.record_depth(queue.len());
        ready
    }

    /// Fetches the next task for `worker_id`.
    ///
    /// Sources are tried in order: the worker's own queue (most recently added first),
    /// the shared heap, and finally the queues of other workers.
    pub fn dequeue_for_worker(&self, worker_id: usize) -> Option<Task> {
        {
            let queues = self.worker_queues.read();
            if let Some(own_queue) = queues.get(worker_id) {
                if let Some(task) = own_queue.lock().pop() {
                    return Some(task);
                }
            }
        }
        self.dequeue().or_else(|| self.steal_work(worker_id))
    }

    /// Takes work from the first other worker that has at least two queued tasks.
    ///
    /// Half of the victim's tasks (rounded down, taken from the front) are moved: the
    /// first is returned and the rest go to `worker_id`'s own queue. If `worker_id` has
    /// no queue of its own, only a single task is taken so that nothing is left without
    /// an owner. Every moved task is counted in the `stolen` statistic.
    fn steal_work(&self, worker_id: usize) -> Option<Task> {
        let queues = self.worker_queues.read();
        let own_queue = queues.get(worker_id);

        for (victim_id, victim) in queues.iter().enumerate() {
            if victim_id == worker_id {
                continue;
            }

            let mut victim_queue = victim.lock();
            let half = victim_queue.len() / 2;
            let steal_count = if own_queue.is_some() { half } else { half.min(1) };
            if steal_count == 0 {
                continue;
            }
            let stolen: Vec<Task> = victim_queue.drain(..steal_count).collect();
            // Release the victim before touching our own queue: two workers stealing
            // from each other must never hold one lock while waiting for the other.
            drop(victim_queue);

            self.stats
                .stolen
                .fetch_add(stolen.len() as u64, Ordering::Relaxed);

            let mut stolen = stolen.into_iter();
            let first = stolen.next();
            if let Some(own) = own_queue {
                own.lock().extend(stolen);
            }
            return first;
        }
        None
    }

    /// Reports whether every dependency of `task` has been marked completed.
    fn dependencies_satisfied(&self, task: &Task) -> bool {
        if task.dependencies.is_empty() {
            return true;
        }
        let completed = self.completed.read();
        task.dependencies
            .iter()
            .all(|dep_id| completed.contains(dep_id))
    }

    /// Records `task_id` as completed, which unblocks tasks that depend on it.
    pub fn mark_completed(&self, task_id: u64) {
        self.completed.write().insert(task_id);
        self.stats.completed.fetch_add(1, Ordering::Relaxed);
    }

    /// Records `task_id` as failed.
    pub fn mark_failed(&self, task_id: u64) {
        self.failed.write().insert(task_id);
        self.stats.failed.fetch_add(1, Ordering::Relaxed);
    }

    /// Puts `task` back on the shared heap for another attempt and bumps its retry count.
    ///
    /// Retries are not subject to the `max_size` limit.
    ///
    /// # Errors
    ///
    /// Returns [`QueueError::MaxRetriesExceeded`] when the task has used up its retries.
    pub fn requeue_for_retry(&self, mut task: Task) -> Result<(), QueueError> {
        if !task.can_retry() {
            return Err(QueueError::MaxRetriesExceeded);
        }
        task.increment_retry();
        self.stats.retried.fetch_add(1, Ordering::Relaxed);

        let mut queue = self.queue.lock();
        queue.push(PrioritizedTask::new(task));
        self.record_depth(queue.len());
        Ok(())
    }

    /// Number of tasks in the shared heap (per-worker queues are not included).
    pub fn len(&self) -> usize {
        self.queue.lock().len()
    }

    /// Reports whether the shared heap is empty (per-worker queues are not considered).
    pub fn is_empty(&self) -> bool {
        self.queue.lock().is_empty()
    }

    /// Publishes `depth` as the current heap depth and folds it into the high-water mark.
    ///
    /// Takes the length as an argument instead of locking the heap, so callers can
    /// invoke it while they still hold the heap's guard.
    fn record_depth(&self, depth: usize) {
        let depth = depth as u64;
        self.stats.current_depth.store(depth, Ordering::Relaxed);
        self.stats.max_depth.fetch_max(depth, Ordering::Relaxed);
    }

    /// Returns a point-in-time copy of the counters.
    pub fn stats(&self) -> QueueStatsSnapshot {
        self.stats.snapshot()
    }

    /// Forgets all completed and failed task IDs.
    ///
    /// Tasks that depend on a forgotten ID become blocked again until it is re-marked.
    pub fn clear_tracking(&self) {
        self.completed.write().clear();
        self.failed.write().clear();
    }

    /// Returns clones of all tasks of `task_type` currently in the shared heap, in
    /// unspecified order.
    pub fn tasks_by_type(&self, task_type: TaskType) -> Vec<Task> {
        let queue = self.queue.lock();
        queue
            .iter()
            .filter(|pt| pt.task.task_type == task_type)
            .map(|pt| pt.task.clone())
            .collect()
    }

    /// Removes the task with `task_id` from the shared heap.
    ///
    /// Returns `true` (and counts one cancellation) if anything was removed. Tasks that
    /// sit in a per-worker queue are not affected.
    pub fn cancel(&self, task_id: u64) -> bool {
        let mut queue = self.queue.lock();
        let initial_len = queue.len();
        queue.retain(|pt| pt.task.id != task_id);

        let removed = queue.len() < initial_len;
        if removed {
            self.stats.cancelled.fetch_add(1, Ordering::Relaxed);
            self.record_depth(queue.len());
        }
        removed
    }
}
impl Default for TaskQueue {
/// Equivalent to `TaskQueue::new(10000)`.
fn default() -> Self {
Self::new(10000)
}
}
/// Live counters for a `TaskQueue`; all of them are atomics that `TaskQueue` updates with
/// relaxed ordering.
pub struct QueueStats {
/// Tasks accepted by `enqueue`.
pub enqueued: AtomicU64,
/// Tasks handed out by `dequeue`.
pub dequeued: AtomicU64,
/// Calls to `mark_completed`.
pub completed: AtomicU64,
/// Calls to `mark_failed`.
pub failed: AtomicU64,
/// Expired tasks discarded by `dequeue`.
pub expired: AtomicU64,
/// Tasks refused by `enqueue` because the queue was full.
pub rejected: AtomicU64,
/// Tasks put back by `requeue_for_retry`.
pub retried: AtomicU64,
/// Tasks drained from another worker's queue by `steal_work`.
pub stolen: AtomicU64,
/// Successful `cancel` calls.
pub cancelled: AtomicU64,
/// Tasks enqueued while their dependencies were not yet completed; never decremented
/// in this file.
pub pending: AtomicU64,
/// Most recently recorded length of the shared heap.
pub current_depth: AtomicU64,
/// Largest length of the shared heap recorded so far.
pub max_depth: AtomicU64,
}
impl QueueStats {
    /// Creates a counter set with every counter at zero.
    pub fn new() -> Self {
        let zero = || AtomicU64::new(0);
        Self {
            enqueued: zero(),
            dequeued: zero(),
            completed: zero(),
            failed: zero(),
            expired: zero(),
            rejected: zero(),
            retried: zero(),
            stolen: zero(),
            cancelled: zero(),
            pending: zero(),
            current_depth: zero(),
            max_depth: zero(),
        }
    }

    /// Copies the current counter values into a plain [`QueueStatsSnapshot`].
    ///
    /// Each counter is read independently with relaxed ordering, so under concurrent
    /// updates the snapshot is not guaranteed to be mutually consistent.
    pub fn snapshot(&self) -> QueueStatsSnapshot {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        QueueStatsSnapshot {
            enqueued: read(&self.enqueued),
            dequeued: read(&self.dequeued),
            completed: read(&self.completed),
            failed: read(&self.failed),
            expired: read(&self.expired),
            rejected: read(&self.rejected),
            retried: read(&self.retried),
            stolen: read(&self.stolen),
            cancelled: read(&self.cancelled),
            pending: read(&self.pending),
            current_depth: read(&self.current_depth),
            max_depth: read(&self.max_depth),
        }
    }
}
impl Default for QueueStats {
/// Equivalent to `QueueStats::new()`: every counter starts at zero.
fn default() -> Self {
Self::new()
}
}
/// Plain-value copy of [`QueueStats`] produced by `QueueStats::snapshot`; each field holds
/// the value of the counter with the same name at the time it was read.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueueStatsSnapshot {
/// Value of the `enqueued` counter.
pub enqueued: u64,
/// Value of the `dequeued` counter.
pub dequeued: u64,
/// Value of the `completed` counter.
pub completed: u64,
/// Value of the `failed` counter.
pub failed: u64,
/// Value of the `expired` counter.
pub expired: u64,
/// Value of the `rejected` counter.
pub rejected: u64,
/// Value of the `retried` counter.
pub retried: u64,
/// Value of the `stolen` counter.
pub stolen: u64,
/// Value of the `cancelled` counter.
pub cancelled: u64,
/// Value of the `pending` counter.
pub pending: u64,
/// Value of the `current_depth` gauge.
pub current_depth: u64,
/// Value of the `max_depth` high-water mark.
pub max_depth: u64,
}
impl QueueStatsSnapshot {
    /// Renders the snapshot as a JSON object.
    ///
    /// Besides one entry per counter, the object carries a derived `success_rate`:
    /// `completed / (completed + failed)`, or `1.0` when nothing has finished yet.
    pub fn to_json(&self) -> serde_json::Value {
        let finished = self.completed + self.failed;
        let success_rate = if finished > 0 {
            self.completed as f64 / finished as f64
        } else {
            1.0
        };

        serde_json::json!({
            "enqueued": self.enqueued,
            "dequeued": self.dequeued,
            "completed": self.completed,
            "failed": self.failed,
            "expired": self.expired,
            "rejected": self.rejected,
            "retried": self.retried,
            "stolen": self.stolen,
            "cancelled": self.cancelled,
            "pending": self.pending,
            "current_depth": self.current_depth,
            "max_depth": self.max_depth,
            "success_rate": success_rate,
        })
    }
}
/// Errors reported by `TaskQueue` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueueError {
/// Returned by `TaskQueue::enqueue` when the queue has reached its `max_size`.
QueueFull,
/// Returned by `TaskQueue::requeue_for_retry` when `Task::can_retry` is false.
MaxRetriesExceeded,
/// Not produced anywhere in this file.
TaskNotFound,
/// Not produced anywhere in this file.
DependenciesNotMet,
}
impl std::fmt::Display for QueueError {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            QueueError::QueueFull => "Task queue is full",
            QueueError::MaxRetriesExceeded => "Maximum retries exceeded",
            QueueError::TaskNotFound => "Task not found",
            QueueError::DependenciesNotMet => "Task dependencies not met",
        };
        f.write_str(message)
    }
}
// Uses the trait's default methods; `Debug` and `Display` are provided above.
impl std::error::Error for QueueError {}
// Process-wide registry, created on first access through `get_task_queues`.
static TASK_QUEUES: OnceLock<TaskQueueRegistry> = OnceLock::new();
/// Groups one `TaskQueue` per workload class; `get_queue` maps a `TaskType` to its queue.
pub struct TaskQueueRegistry {
/// Queue for `Query`, `Insert` and `Delete` tasks (capacity 10000 in `new`).
pub queries: TaskQueue,
/// Queue for `Maintenance`, `IndexBuild` and `StatsCollection` tasks (capacity 1000 in `new`).
pub maintenance: TaskQueue,
/// Queue for `Training` tasks (capacity 100 in `new`).
pub training: TaskQueue,
/// Queue for `Integrity` tasks (capacity 500 in `new`).
pub integrity: TaskQueue,
}
impl TaskQueueRegistry {
    /// Creates the four queues with their fixed capacities: 10000 for queries, 1000 for
    /// maintenance, 100 for training and 500 for integrity.
    pub fn new() -> Self {
        let queries = TaskQueue::new(10_000);
        let maintenance = TaskQueue::new(1_000);
        let training = TaskQueue::new(100);
        let integrity = TaskQueue::new(500);
        Self {
            queries,
            maintenance,
            training,
            integrity,
        }
    }

    /// Returns the queue responsible for tasks of `task_type`.
    pub fn get_queue(&self, task_type: TaskType) -> &TaskQueue {
        // Exhaustive on purpose: a new `TaskType` variant must be routed explicitly.
        match task_type {
            TaskType::Training => &self.training,
            TaskType::Integrity => &self.integrity,
            TaskType::Query | TaskType::Insert | TaskType::Delete => &self.queries,
            TaskType::Maintenance | TaskType::IndexBuild | TaskType::StatsCollection => {
                &self.maintenance
            }
        }
    }

    /// Returns a JSON object with one statistics entry per queue, keyed by queue name.
    pub fn all_stats(&self) -> serde_json::Value {
        let report = |queue: &TaskQueue| queue.stats().to_json();
        serde_json::json!({
            "queries": report(&self.queries),
            "maintenance": report(&self.maintenance),
            "training": report(&self.training),
            "integrity": report(&self.integrity),
        })
    }
}
impl Default for TaskQueueRegistry {
/// Equivalent to `TaskQueueRegistry::new()`.
fn default() -> Self {
Self::new()
}
}
/// Returns the process-wide registry, creating it with `TaskQueueRegistry::new` on the
/// first call.
pub fn get_task_queues() -> &'static TaskQueueRegistry {
TASK_QUEUES.get_or_init(TaskQueueRegistry::new)
}
/// Forces creation of the global registry and emits a log line through `pgrx::log!`.
pub fn init_task_queues() {
// The returned reference is not needed; the call is made only for its initialization effect.
let _ = get_task_queues();
pgrx::log!("Task queues initialized");
}
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, `0` is returned instead of
/// panicking; callers then see "no time has passed" rather than a crash.
fn current_epoch_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    // Built from whole seconds plus the millisecond remainder so that no narrowing
    // cast is needed; saturates instead of wrapping on (absurdly) large values.
    since_epoch
        .as_secs()
        .saturating_mul(1000)
        .saturating_add(u64::from(since_epoch.subsec_millis()))
}
// Unit tests for task construction and for the queue's public operations.
#[cfg(test)]
mod tests {
use super::*;
// Builder methods set priority and collection; a task without a deadline is not expired.
#[test]
fn test_task_creation() {
let task = Task::new(TaskType::Query, vec![1, 2, 3])
.with_priority(TaskPriority::High)
.with_collection(1);
assert_eq!(task.task_type, TaskType::Query);
assert_eq!(task.priority, TaskPriority::High);
assert_eq!(task.collection_id, Some(1));
assert!(!task.is_expired());
}
// A deadline one second in the past is expired; one second in the future is not.
#[test]
fn test_task_expiry() {
let mut task = Task::new(TaskType::Query, vec![]);
task.deadline_ms = current_epoch_ms() - 1000; assert!(task.is_expired());
task.deadline_ms = current_epoch_ms() + 1000; assert!(!task.is_expired());
}
// With max_retries = 2, retrying is allowed until two retries have been recorded.
#[test]
fn test_task_retry() {
let mut task = Task::new(TaskType::Query, vec![]).with_max_retries(2);
assert!(task.can_retry());
task.increment_retry();
assert!(task.can_retry());
task.increment_retry();
assert!(!task.can_retry());
}
// The high-priority task is dequeued first even though it was enqueued second.
#[test]
fn test_queue_basic_operations() {
let queue = TaskQueue::new(100);
let task1 = Task::new(TaskType::Query, vec![1]).with_priority(TaskPriority::Low);
let task2 = Task::new(TaskType::Query, vec![2]).with_priority(TaskPriority::High);
queue.enqueue(task1).unwrap();
queue.enqueue(task2).unwrap();
assert_eq!(queue.len(), 2);
let dequeued = queue.dequeue().unwrap();
assert_eq!(dequeued.priority, TaskPriority::High);
}
// Clones of one task (same ID) are used; the third enqueue exceeds the capacity of 2.
#[test]
fn test_queue_full() {
let queue = TaskQueue::new(2);
let task = Task::new(TaskType::Query, vec![]);
assert!(queue.enqueue(task.clone()).is_ok());
assert!(queue.enqueue(task.clone()).is_ok());
assert_eq!(queue.enqueue(task.clone()), Err(QueueError::QueueFull));
}
// The dependency (ID 1000) is marked completed before dequeuing, so the dependent
// task (ID 2000) is handed out.
#[test]
fn test_queue_dependencies() {
let queue = TaskQueue::new(100);
let task1_id = 1000;
let task2 = Task {
id: 2000,
dependencies: vec![task1_id],
..Task::new(TaskType::Query, vec![])
};
queue.enqueue(task2.clone()).unwrap();
queue.mark_completed(task1_id);
let dequeued = queue.dequeue().unwrap();
assert_eq!(dequeued.id, 2000);
}
// Cancelling the only queued task reports success and empties the queue.
#[test]
fn test_queue_cancel() {
let queue = TaskQueue::new(100);
let task = Task::new(TaskType::Query, vec![]);
let task_id = task.id;
queue.enqueue(task).unwrap();
assert_eq!(queue.len(), 1);
assert!(queue.cancel(task_id));
assert_eq!(queue.len(), 0);
}
// Worker 1 has nothing of its own and the shared heap is empty, so it must obtain a
// task from worker 0's queue, which was filled directly with four tasks.
#[test]
fn test_work_stealing() {
let queue = TaskQueue::new(100);
let worker0 = queue.add_worker();
let worker1 = queue.add_worker();
{
let queues = queue.worker_queues.read();
let mut w0_queue = queues[worker0].lock();
for i in 0..4 {
w0_queue.push(Task::new(TaskType::Query, vec![i as u8]));
}
}
let stolen = queue.dequeue_for_worker(worker1);
assert!(stolen.is_some());
let stats = queue.stats();
assert!(stats.stolen > 0);
}
// One enqueue, one dequeue and one completion are each reflected in the counters.
#[test]
fn test_queue_stats() {
let queue = TaskQueue::new(100);
let task = Task::new(TaskType::Query, vec![]);
let task_id = task.id;
queue.enqueue(task).unwrap();
let _ = queue.dequeue();
queue.mark_completed(task_id);
let stats = queue.stats();
assert_eq!(stats.enqueued, 1);
assert_eq!(stats.dequeued, 1);
assert_eq!(stats.completed, 1);
}
}