use crate::error::{FusekiError, FusekiResult};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, VecDeque};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering as AtomicOrdering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Notify, RwLock, Semaphore};
use tracing::{debug, info, instrument, warn};
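/// Scheduling priority for a query; higher values are dispatched first.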
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Priority {
Low = 0,
Normal = 1,
High = 2,
Critical = 3,
}
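/// Tunable limits for admission control and scheduling.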
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConcurrencyConfig {
pub max_global_concurrent: usize,
pub max_per_dataset_concurrent: usize,
pub max_per_user_concurrent: usize,
pub enable_work_stealing: bool,
pub max_queue_size: usize,
pub queue_timeout_secs: u64,
pub enable_load_shedding: bool,
pub load_shedding_threshold: f64,
pub worker_threads: usize,
pub enable_fair_scheduling: bool,
}
impl Default for ConcurrencyConfig {
fn default() -> Self {
ConcurrencyConfig {
max_global_concurrent: 200,
max_per_dataset_concurrent: 50,
max_per_user_concurrent: 10,
enable_work_stealing: true,
max_queue_size: 10000,
            queue_timeout_secs: 300,
            enable_load_shedding: true,
load_shedding_threshold: 0.85,
worker_threads: num_cpus::get(),
enable_fair_scheduling: true,
}
}
}
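/// A query awaiting execution, with the metadata used for scheduling and
/// admission decisions.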
#[derive(Debug)]
pub struct QueryRequest {
pub id: String,
pub dataset: String,
pub user_id: Option<String>,
pub query: String,
pub priority: Priority,
pub estimated_time_ms: u64,
pub estimated_memory_mb: u64,
pub queued_at: Instant,
pub timeout: Duration,
}
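// Equality is identity-based: two requests are the same request iff their ids
// match. The ordering below is what drives the scheduling heap.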
impl PartialEq for QueryRequest {
fn eq(&self, other: &Self) -> bool {
self.id == other.id
}
}
impl Eq for QueryRequest {}
impl PartialOrd for QueryRequest {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for QueryRequest {
    fn cmp(&self, other: &Self) -> Ordering {
        // BinaryHeap is a max-heap: a higher priority must compare greater so
        // it pops first, and among equal priorities the earlier-queued request
        // must compare greater so dispatch stays FIFO.
        self.priority
            .cmp(&other.priority)
            .then_with(|| other.queued_at.cmp(&self.queued_at))
    }
}
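/// Per-worker counters, surfaced through [`ConcurrencyStats`].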
#[derive(Debug, Clone, Serialize)]
pub struct WorkerStats {
pub worker_id: usize,
pub tasks_executed: u64,
pub tasks_stolen: u64,
pub steal_attempts: u64,
pub current_queue_size: usize,
pub total_execution_time_ms: u64,
pub idle_time_ms: u64,
}
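/// A scheduler worker owning a local deque of pending requests: the owner
/// pops from the front, idle peers steal from the back.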
struct Worker {
id: usize,
local_queue: Arc<RwLock<VecDeque<QueryRequest>>>,
stats: Arc<RwLock<WorkerStats>>,
notify: Arc<Notify>,
}
impl Worker {
fn new(id: usize) -> Self {
Worker {
id,
local_queue: Arc::new(RwLock::new(VecDeque::new())),
stats: Arc::new(RwLock::new(WorkerStats {
worker_id: id,
tasks_executed: 0,
tasks_stolen: 0,
steal_attempts: 0,
current_queue_size: 0,
total_execution_time_ms: 0,
idle_time_ms: 0,
})),
notify: Arc::new(Notify::new()),
}
}
    async fn push(&self, request: QueryRequest) {
        let mut queue = self.local_queue.write().await;
        queue.push_back(request);
        // Keep the reported queue size in sync, as pop() and steal() do.
        let mut stats = self.stats.write().await;
        stats.current_queue_size = queue.len();
        self.notify.notify_one();
    }
async fn pop(&self) -> Option<QueryRequest> {
let mut queue = self.local_queue.write().await;
let request = queue.pop_front();
if request.is_some() {
let mut stats = self.stats.write().await;
stats.current_queue_size = queue.len();
}
request
}
    async fn steal(&self) -> Option<QueryRequest> {
        // Steal from the back of the deque, opposite the end the owner pops
        // from (the conventional work-stealing layout).
        let mut queue = self.local_queue.write().await;
        let request = queue.pop_back();
        if request.is_some() {
            let mut stats = self.stats.write().await;
            stats.current_queue_size = queue.len();
        }
        request
    }
async fn queue_size(&self) -> usize {
self.local_queue.read().await.len()
}
}
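/// Aggregate scheduler statistics, refreshed periodically by the monitor task.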
#[derive(Debug, Clone, Serialize)]
pub struct ConcurrencyStats {
pub active_requests: usize,
pub queued_requests: usize,
pub total_requests: u64,
pub completed_requests: u64,
pub failed_requests: u64,
pub rejected_requests: u64,
pub timed_out_requests: u64,
pub average_wait_time_ms: f64,
pub average_execution_time_ms: f64,
pub current_load: f64,
pub worker_stats: Vec<WorkerStats>,
}
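/// Admission controller and scheduler for queries: enforces global,
/// per-dataset, and per-user concurrency limits, sheds load under pressure,
/// and distributes queued requests across work-stealing workers.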
pub struct ConcurrencyManager {
config: ConcurrencyConfig,
global_semaphore: Arc<Semaphore>,
dataset_semaphores: Arc<DashMap<String, Arc<Semaphore>>>,
user_semaphores: Arc<DashMap<String, Arc<Semaphore>>>,
priority_queue: Arc<RwLock<BinaryHeap<QueryRequest>>>,
workers: Arc<Vec<Worker>>,
active_requests: Arc<DashMap<String, Instant>>,
stats: Arc<RwLock<ConcurrencyStats>>,
total_requests: Arc<AtomicU64>,
completed_requests: Arc<AtomicU64>,
failed_requests: Arc<AtomicU64>,
rejected_requests: Arc<AtomicU64>,
timed_out_requests: Arc<AtomicU64>,
shutdown: Arc<tokio::sync::watch::Sender<bool>>,
}
impl ConcurrencyManager {
pub fn new(config: ConcurrencyConfig) -> Arc<Self> {
let global_semaphore = Arc::new(Semaphore::new(config.max_global_concurrent));
let dataset_semaphores = Arc::new(DashMap::new());
let user_semaphores = Arc::new(DashMap::new());
let priority_queue = Arc::new(RwLock::new(BinaryHeap::new()));
let mut workers = Vec::new();
for i in 0..config.worker_threads {
workers.push(Worker::new(i));
}
let workers = Arc::new(workers);
let active_requests = Arc::new(DashMap::new());
let stats = Arc::new(RwLock::new(ConcurrencyStats {
active_requests: 0,
queued_requests: 0,
total_requests: 0,
completed_requests: 0,
failed_requests: 0,
rejected_requests: 0,
timed_out_requests: 0,
average_wait_time_ms: 0.0,
average_execution_time_ms: 0.0,
current_load: 0.0,
worker_stats: Vec::new(),
}));
let (shutdown_tx, _) = tokio::sync::watch::channel(false);
let manager = Arc::new(ConcurrencyManager {
config,
global_semaphore,
dataset_semaphores,
user_semaphores,
priority_queue,
workers,
active_requests,
stats,
total_requests: Arc::new(AtomicU64::new(0)),
completed_requests: Arc::new(AtomicU64::new(0)),
failed_requests: Arc::new(AtomicU64::new(0)),
rejected_requests: Arc::new(AtomicU64::new(0)),
timed_out_requests: Arc::new(AtomicU64::new(0)),
shutdown: Arc::new(shutdown_tx),
});
        // Always run the scheduler: it is the only consumer of the shared
        // priority queue. Work stealing itself is gated inside its loop.
        manager.clone().start_scheduler();
manager.clone().start_monitoring();
info!(
"Concurrency manager initialized with {} workers and max {} concurrent requests",
manager.config.worker_threads, manager.config.max_global_concurrent
);
manager
}
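    /// Admits a query for execution, enforcing load shedding, the queue-size
    /// cap, and the global, per-dataset, and per-user concurrency limits.
    ///
    /// The returned [`QueryPermit`] must be held for the duration of query
    /// execution; dropping it releases all permits and records the outcome.
    /// A minimal usage sketch (`run_query` is a placeholder for the caller's
    /// own execution logic):
    ///
    /// ```ignore
    /// let permit = manager.submit(request).await?;
    /// match run_query(&permit).await {
    ///     Ok(_) => permit.complete(),
    ///     Err(_) => permit.fail(),
    /// }
    /// ```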
#[instrument(skip(self, request))]
pub async fn submit(&self, request: QueryRequest) -> FusekiResult<QueryPermit> {
self.total_requests.fetch_add(1, AtomicOrdering::Relaxed);
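        // Shed load before queuing when utilization already exceeds the
        // configured threshold.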
if self.config.enable_load_shedding {
let load = self.calculate_current_load().await;
if load > self.config.load_shedding_threshold {
self.rejected_requests.fetch_add(1, AtomicOrdering::Relaxed);
return Err(FusekiError::service_unavailable(
"Server is overloaded, request rejected",
));
}
}
let queued = self.priority_queue.read().await.len();
if queued >= self.config.max_queue_size {
self.rejected_requests.fetch_add(1, AtomicOrdering::Relaxed);
return Err(FusekiError::service_unavailable("Request queue is full"));
}
let request_id = request.id.clone();
let dataset = request.dataset.clone();
let user_id = request.user_id.clone();
let queued_at = request.queued_at;
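        // Hand the request itself to the scheduler's shared queue; the
        // semaphore permits acquired below are what actually gate execution.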
{
let mut queue = self.priority_queue.write().await;
queue.push(request);
}
let global_permit = self
.acquire_with_timeout(
&self.global_semaphore,
Duration::from_secs(self.config.queue_timeout_secs),
)
.await?;
let dataset_permit = self.acquire_dataset_permit(&dataset).await?;
let user_permit = if let Some(user) = &user_id {
Some(self.acquire_user_permit(user).await?)
} else {
None
};
self.active_requests
.insert(request_id.clone(), Instant::now());
let wait_time = queued_at.elapsed();
        debug!(
            "Query {} acquired permits after {:.2}ms",
            request_id,
            wait_time.as_secs_f64() * 1000.0
        );
Ok(QueryPermit {
request_id,
_global_permit: global_permit,
_dataset_permit: dataset_permit,
_user_permit: user_permit,
started_at: Instant::now(),
active_requests: Arc::clone(&self.active_requests),
completed_requests: Arc::clone(&self.completed_requests),
failed_requests: Arc::clone(&self.failed_requests),
completed_successfully: Arc::new(AtomicBool::new(false)),
})
}
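    /// Acquires an owned permit from `semaphore`, converting a timeout into a
    /// request-timeout error and counting it in the timed-out statistics.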
async fn acquire_with_timeout(
&self,
semaphore: &Arc<Semaphore>,
timeout: Duration,
) -> FusekiResult<tokio::sync::OwnedSemaphorePermit> {
tokio::time::timeout(timeout, semaphore.clone().acquire_owned())
.await
.map_err(|_| {
self.timed_out_requests
.fetch_add(1, AtomicOrdering::Relaxed);
FusekiError::request_timeout("Request timed out waiting for execution slot")
})?
.map_err(|_| FusekiError::server_error("Semaphore closed"))
}
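    /// Acquires a per-dataset permit, creating the dataset's semaphore lazily
    /// on first use.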
async fn acquire_dataset_permit(
&self,
dataset: &str,
) -> FusekiResult<tokio::sync::OwnedSemaphorePermit> {
let semaphore = self
.dataset_semaphores
.entry(dataset.to_string())
.or_insert_with(|| Arc::new(Semaphore::new(self.config.max_per_dataset_concurrent)))
.clone();
self.acquire_with_timeout(
&semaphore,
Duration::from_secs(self.config.queue_timeout_secs),
)
.await
}
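    /// Acquires a per-user permit, creating the user's semaphore lazily on
    /// first use.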
async fn acquire_user_permit(
&self,
user_id: &str,
) -> FusekiResult<tokio::sync::OwnedSemaphorePermit> {
let semaphore = self
.user_semaphores
.entry(user_id.to_string())
.or_insert_with(|| Arc::new(Semaphore::new(self.config.max_per_user_concurrent)))
.clone();
self.acquire_with_timeout(
&semaphore,
Duration::from_secs(self.config.queue_timeout_secs),
)
.await
}
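    /// Spawns the background scheduler: every 10ms it drains the shared
    /// priority queue into worker-local queues and, when enabled, rebalances
    /// them by work stealing. The task exits once shutdown is signalled.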
fn start_scheduler(self: Arc<Self>) {
let manager = self.clone();
tokio::spawn(async move {
let mut shutdown_rx = manager.shutdown.subscribe();
loop {
tokio::select! {
_ = shutdown_rx.changed() => {
if *shutdown_rx.borrow() {
break;
}
}
_ = tokio::time::sleep(Duration::from_millis(10)) => {
manager.distribute_work().await;
if manager.config.enable_work_stealing {
manager.perform_work_stealing().await;
}
}
}
}
});
}
    async fn distribute_work(&self) {
        // Drain the shared queue before dispatching so the heap's write lock
        // is not held across the awaits on worker-local queues.
        let drained: Vec<QueryRequest> = {
            let mut queue = self.priority_queue.write().await;
            std::iter::from_fn(|| queue.pop()).collect()
        };
        for request in drained {
            let worker_idx = self.find_least_loaded_worker().await;
            if let Some(worker) = self.workers.get(worker_idx) {
                worker.push(request).await;
            }
        }
    }
async fn find_least_loaded_worker(&self) -> usize {
let mut min_size = usize::MAX;
let mut min_idx = 0;
for (idx, worker) in self.workers.iter().enumerate() {
let size = worker.queue_size().await;
if size < min_size {
min_size = size;
min_idx = idx;
}
}
min_idx
}
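    /// One stealing pass: a nearly idle worker (fewer than 2 queued tasks)
    /// takes a task from the back of its neighbour's queue, but only when the
    /// neighbour has a backlog of more than 4.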
async fn perform_work_stealing(&self) {
let num_workers = self.workers.len();
if num_workers < 2 {
return;
}
        for thief_idx in 0..num_workers {
            let thief = &self.workers[thief_idx];
            let thief_size = thief.queue_size().await;
            if thief_size < 2 {
                // With at least two workers, the neighbour is never the thief
                // itself.
                let victim_idx = (thief_idx + 1) % num_workers;
                if let Some(victim) = self.workers.get(victim_idx) {
                    let victim_size = victim.queue_size().await;
                    if victim_size > 4 {
                        let stolen = victim.steal().await;
                        // Count every attempt; count a theft only on success.
                        {
                            let mut stats = thief.stats.write().await;
                            stats.steal_attempts += 1;
                            if stolen.is_some() {
                                stats.tasks_stolen += 1;
                            }
                        }
                        if let Some(stolen_task) = stolen {
                            thief.push(stolen_task).await;
                        }
                    }
                }
            }
        }
}
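    /// Spawns the monitor task, which refreshes the aggregate statistics every
    /// five seconds until shutdown is signalled.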
fn start_monitoring(self: Arc<Self>) {
let manager = self.clone();
tokio::spawn(async move {
let mut shutdown_rx = manager.shutdown.subscribe();
loop {
tokio::select! {
_ = shutdown_rx.changed() => {
if *shutdown_rx.borrow() {
break;
}
}
_ = tokio::time::sleep(Duration::from_secs(5)) => {
manager.update_statistics().await;
}
}
}
});
}
async fn update_statistics(&self) {
let active = self.active_requests.len();
let queued = self.priority_queue.read().await.len();
let total = self.total_requests.load(AtomicOrdering::Relaxed);
let completed = self.completed_requests.load(AtomicOrdering::Relaxed);
let failed = self.failed_requests.load(AtomicOrdering::Relaxed);
let rejected = self.rejected_requests.load(AtomicOrdering::Relaxed);
let timed_out = self.timed_out_requests.load(AtomicOrdering::Relaxed);
let load = self.calculate_current_load().await;
let mut worker_stats = Vec::new();
for worker in self.workers.iter() {
let stats = worker.stats.read().await.clone();
worker_stats.push(stats);
}
let mut stats = self.stats.write().await;
stats.active_requests = active;
stats.queued_requests = queued;
stats.total_requests = total;
stats.completed_requests = completed;
stats.failed_requests = failed;
stats.rejected_requests = rejected;
stats.timed_out_requests = timed_out;
stats.current_load = load;
stats.worker_stats = worker_stats;
}
async fn calculate_current_load(&self) -> f64 {
let active = self.active_requests.len();
let max_concurrent = self.config.max_global_concurrent;
if max_concurrent == 0 {
return 0.0;
}
(active as f64) / (max_concurrent as f64)
}
pub async fn get_stats(&self) -> ConcurrencyStats {
self.update_statistics().await;
self.stats.read().await.clone()
}
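    /// Records a request outcome directly. The permit-based path performs the
    /// same bookkeeping in `QueryPermit`'s `Drop`, so this is currently unused
    /// within this module.
    #[allow(dead_code)]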
fn mark_completed(&self, request_id: &str, success: bool) {
self.active_requests.remove(request_id);
if success {
self.completed_requests
.fetch_add(1, AtomicOrdering::Relaxed);
} else {
self.failed_requests.fetch_add(1, AtomicOrdering::Relaxed);
}
}
pub async fn shutdown(&self) {
info!("Shutting down concurrency manager");
let _ = self.shutdown.send(true);
}
}
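/// RAII guard for an admitted query. It holds the global, per-dataset, and
/// optional per-user permits; dropping it releases them and records the
/// outcome (a failure unless `complete` was called first).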
pub struct QueryPermit {
request_id: String,
_global_permit: tokio::sync::OwnedSemaphorePermit,
_dataset_permit: tokio::sync::OwnedSemaphorePermit,
_user_permit: Option<tokio::sync::OwnedSemaphorePermit>,
started_at: Instant,
active_requests: Arc<DashMap<String, Instant>>,
completed_requests: Arc<AtomicU64>,
failed_requests: Arc<AtomicU64>,
completed_successfully: Arc<AtomicBool>,
}
impl QueryPermit {
pub fn request_id(&self) -> &str {
&self.request_id
}
pub fn elapsed(&self) -> Duration {
self.started_at.elapsed()
}
pub fn complete(self) {
let elapsed = self.elapsed();
        debug!(
            "Query {} completed in {:.2}ms",
            self.request_id,
            elapsed.as_secs_f64() * 1000.0
        );
self.completed_successfully
.store(true, AtomicOrdering::Release);
}
pub fn fail(self) {
let elapsed = self.elapsed();
        warn!(
            "Query {} failed after {:.2}ms",
            self.request_id,
            elapsed.as_secs_f64() * 1000.0
        );
}
}
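// Drop does the bookkeeping so that an outcome is recorded even when the
// permit is dropped early, e.g. on a panic or an error return path.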
impl Drop for QueryPermit {
fn drop(&mut self) {
self.active_requests.remove(&self.request_id);
let success = self.completed_successfully.load(AtomicOrdering::Acquire);
if success {
self.completed_requests
.fetch_add(1, AtomicOrdering::Relaxed);
} else {
self.failed_requests.fetch_add(1, AtomicOrdering::Relaxed);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;
#[tokio::test]
async fn test_concurrency_manager_creation() {
let config = ConcurrencyConfig::default();
let manager = ConcurrencyManager::new(config);
let stats = manager.get_stats().await;
assert_eq!(stats.active_requests, 0);
assert_eq!(stats.queued_requests, 0);
}
#[tokio::test]
async fn test_request_submission() {
let config = ConcurrencyConfig {
max_global_concurrent: 10,
..Default::default()
};
let manager = ConcurrencyManager::new(config);
let request = QueryRequest {
id: Uuid::new_v4().to_string(),
dataset: "test".to_string(),
user_id: Some("user1".to_string()),
query: "SELECT * WHERE { ?s ?p ?o }".to_string(),
priority: Priority::Normal,
estimated_time_ms: 100,
estimated_memory_mb: 10,
queued_at: Instant::now(),
timeout: Duration::from_secs(30),
};
let permit = manager.submit(request).await;
assert!(permit.is_ok());
let permit = permit.unwrap();
assert!(!permit.request_id().is_empty());
permit.complete();
}
#[tokio::test]
async fn test_load_shedding() {
let config = ConcurrencyConfig {
max_global_concurrent: 2,
enable_load_shedding: true,
load_shedding_threshold: 0.5,
..Default::default()
};
let manager = ConcurrencyManager::new(config);
let mut permits = Vec::new();
for i in 0..2 {
let request = QueryRequest {
id: format!("req{}", i),
dataset: "test".to_string(),
user_id: None,
query: "SELECT * WHERE { ?s ?p ?o }".to_string(),
priority: Priority::Normal,
estimated_time_ms: 100,
estimated_memory_mb: 10,
queued_at: Instant::now(),
timeout: Duration::from_secs(30),
};
if let Ok(permit) = manager.submit(request).await {
permits.push(permit);
}
}
let request = QueryRequest {
id: "overflow".to_string(),
dataset: "test".to_string(),
user_id: None,
query: "SELECT * WHERE { ?s ?p ?o }".to_string(),
priority: Priority::Normal,
estimated_time_ms: 100,
estimated_memory_mb: 10,
queued_at: Instant::now(),
timeout: Duration::from_secs(30),
};
        let result = manager.submit(request).await;
        // Both permits are still held, so load is 1.0 > 0.5 and the request
        // must be shed.
        assert!(result.is_err());
}
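    #[test]
    fn test_priority_ordering() {
        // Sanity-check the heap ordering: higher-priority requests pop first.
        // `make` is a local helper for building otherwise-identical requests.
        let make = |id: &str, priority: Priority| QueryRequest {
            id: id.to_string(),
            dataset: "test".to_string(),
            user_id: None,
            query: "SELECT * WHERE { ?s ?p ?o }".to_string(),
            priority,
            estimated_time_ms: 100,
            estimated_memory_mb: 10,
            queued_at: Instant::now(),
            timeout: Duration::from_secs(30),
        };
        let mut heap = BinaryHeap::new();
        heap.push(make("low", Priority::Low));
        heap.push(make("critical", Priority::Critical));
        heap.push(make("normal", Priority::Normal));
        assert_eq!(heap.pop().unwrap().id, "critical");
        assert_eq!(heap.pop().unwrap().id, "normal");
        assert_eq!(heap.pop().unwrap().id, "low");
    }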
}