use crate::executor::Dataset;
use crate::query_fingerprinting::{FingerprintConfig, QueryFingerprint, QueryFingerprinter};
use crate::system_load_monitor::AdaptiveConcurrencyController;
use anyhow::Result;
use scirs2_core::metrics::{Counter, Gauge, Timer};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
use tokio::task::JoinHandle;
/// Scheduling priority for a batched query.
///
/// The derived `PartialOrd`/`Ord` compare the explicit discriminant values,
/// so `High > Normal > Low > Background`; the queue-insertion logic in
/// `QueryBatchExecutor::add_query` relies on this ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum QueryPriority {
    /// Latency-sensitive work; inserted ahead of all lower priorities.
    High = 3,
    /// Default priority for ordinary queries.
    #[default]
    Normal = 2,
    /// Deprioritized work.
    Low = 1,
    /// Lowest priority; runs only after everything else queued.
    Background = 0,
}
/// Strategy used by `QueryBatchExecutor` to run a drained batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BatchMode {
    /// Run queries concurrently, bounded by `BatchConfig::max_concurrent`.
    Parallel,
    /// Run queries one at a time, in queue order.
    Sequential,
    /// Default: group queries by fingerprint hash, then run each group in
    /// parallel (see `execute_optimized`).
    #[default]
    Optimized,
    /// Parallel execution whose concurrency limit is adjusted at runtime by
    /// an `AdaptiveConcurrencyController`.
    Adaptive,
}
/// Tuning knobs for `QueryBatchExecutor`.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Upper bound on concurrently executing queries (semaphore size).
    pub max_concurrent: usize,
    /// Memory budget in MiB. NOTE(review): stored but not consulted by the
    /// executor code in this file — confirm enforcement elsewhere.
    pub memory_limit_mb: usize,
    /// CPU utilization cap in `[0.0, 1.0]`. NOTE(review): also not consulted
    /// by the executor code in this file.
    pub cpu_limit: f64,
    /// Execution strategy for a batch (see `BatchMode`).
    pub mode: BatchMode,
    /// When true, queries are fingerprinted on submission so `Optimized`
    /// mode can group structurally similar queries.
    pub enable_grouping: bool,
    /// When true, results are cached and served keyed by fingerprint hash.
    pub enable_caching: bool,
    /// Overall batch deadline. NOTE(review): not enforced in this file.
    pub batch_timeout: Duration,
    /// Per-query deadline, enforced via `tokio::time::timeout`.
    pub query_timeout: Duration,
    /// Queue-insertion policy flag; see `QueryBatchExecutor::add_query`.
    pub fair_scheduling: bool,
}
impl Default for BatchConfig {
fn default() -> Self {
Self {
max_concurrent: num_cpus::get(),
memory_limit_mb: 4096,
cpu_limit: 0.8,
mode: BatchMode::default(),
enable_grouping: true,
enable_caching: true,
batch_timeout: Duration::from_secs(300), query_timeout: Duration::from_secs(60), fair_scheduling: true,
}
}
}
impl BatchConfig {
    /// Caps concurrent query execution; values below one are raised to one.
    pub fn with_max_concurrent(mut self, value: usize) -> Self {
        self.max_concurrent = value.max(1);
        self
    }

    /// Sets the memory budget in MiB (stored as-is).
    pub fn with_memory_limit_mb(mut self, value: usize) -> Self {
        self.memory_limit_mb = value;
        self
    }

    /// Sets the CPU cap, clamped into `[0.0, 1.0]`.
    pub fn with_cpu_limit(mut self, value: f64) -> Self {
        self.cpu_limit = value.clamp(0.0, 1.0);
        self
    }

    /// Selects the batch execution strategy.
    pub fn with_mode(mut self, value: BatchMode) -> Self {
        self.mode = value;
        self
    }

    /// Toggles fingerprint-based query grouping.
    pub fn with_grouping(mut self, value: bool) -> Self {
        self.enable_grouping = value;
        self
    }

    /// Toggles fingerprint-keyed result caching.
    pub fn with_caching(mut self, value: bool) -> Self {
        self.enable_caching = value;
        self
    }

    /// Sets the whole-batch deadline.
    pub fn with_batch_timeout(mut self, value: Duration) -> Self {
        self.batch_timeout = value;
        self
    }

    /// Sets the per-query deadline.
    pub fn with_query_timeout(mut self, value: Duration) -> Self {
        self.query_timeout = value;
        self
    }
}
/// A queued query together with its scheduling metadata and lifecycle
/// timestamps.
#[derive(Debug, Clone)]
pub struct BatchQuery {
    /// Unique id of the form `query_<uuid>`, assigned on submission.
    pub id: String,
    /// The query text to execute.
    pub query: String,
    /// Scheduling priority; determines queue insertion position.
    pub priority: QueryPriority,
    /// Structural fingerprint; `Some` only when grouping is enabled.
    pub fingerprint: Option<QueryFingerprint>,
    /// When the query entered the queue.
    pub submitted_at: Instant,
    /// When execution started; `None` until then.
    pub started_at: Option<Instant>,
    /// When execution finished (or was served from cache); `None` until then.
    pub completed_at: Option<Instant>,
}
/// Outcome of a single executed (or cache-served) query.
#[derive(Debug, Clone)]
pub struct BatchQueryResult {
    /// Id of the query this result belongs to.
    pub id: String,
    /// True when the query completed without error or timeout.
    pub success: bool,
    /// Result payload on success; `None` on failure.
    pub results: Option<String>,
    /// Error message on failure ("Query timeout" when the deadline was hit).
    pub error: Option<String>,
    /// Wall-clock time spent on this query, including cache lookup.
    pub duration: Duration,
    /// Number of lines in the result payload; 0 on failure.
    pub result_count: usize,
}
/// Aggregate metrics for batch execution.
#[derive(Debug, Clone)]
pub struct BatchStatistics {
    /// Queries in the most recent batch.
    pub total_queries: usize,
    /// Queries that completed successfully.
    pub successful_queries: usize,
    /// Queries that failed or timed out.
    pub failed_queries: usize,
    /// Wall-clock duration of the batch.
    pub total_duration: Duration,
    /// `total_duration / successful_queries`, derived.
    pub avg_duration: Duration,
    /// Shortest successful query; `Duration::MAX` until one is observed.
    pub min_duration: Duration,
    /// Longest successful query.
    pub max_duration: Duration,
    /// Sum of per-query result counts.
    pub total_results: usize,
    /// Queries per second, derived.
    pub throughput: f64,
    /// Peak memory observed, in MB.
    pub peak_memory_mb: f64,
    /// Average CPU usage fraction.
    pub avg_cpu_usage: f64,
    /// Queries served from the result cache.
    pub cached_queries: usize,
    /// Fingerprint groups formed in `Optimized` mode.
    pub query_groups: usize,
}

impl BatchStatistics {
    /// Creates an empty record. `min_duration` starts at `Duration::MAX`
    /// so the first observed duration always replaces it.
    pub fn new() -> Self {
        BatchStatistics {
            total_queries: 0,
            successful_queries: 0,
            failed_queries: 0,
            total_duration: Duration::from_secs(0),
            avg_duration: Duration::from_secs(0),
            min_duration: Duration::MAX,
            max_duration: Duration::from_secs(0),
            total_results: 0,
            throughput: 0.0,
            peak_memory_mb: 0.0,
            avg_cpu_usage: 0.0,
            cached_queries: 0,
            query_groups: 0,
        }
    }

    /// Recomputes `throughput` and `avg_duration` from the raw counters.
    /// Does nothing when no queries have run (avoids division by zero).
    pub fn calculate_derived(&mut self) {
        if self.total_queries == 0 {
            return;
        }
        let elapsed = self.total_duration.as_secs_f64();
        if elapsed > 0.0 {
            self.throughput = self.total_queries as f64 / elapsed;
        }
        if self.successful_queries > 0 {
            self.avg_duration = self.total_duration / self.successful_queries as u32;
        }
    }

    /// Fraction of queries that succeeded; 0.0 when nothing has run.
    pub fn success_rate(&self) -> f64 {
        match self.total_queries {
            0 => 0.0,
            n => self.successful_queries as f64 / n as f64,
        }
    }

    /// Fraction of queries served from cache; 0.0 when nothing has run.
    pub fn cache_hit_rate(&self) -> f64 {
        match self.total_queries {
            0 => 0.0,
            n => self.cached_queries as f64 / n as f64,
        }
    }
}

impl Default for BatchStatistics {
    fn default() -> Self {
        Self::new()
    }
}
/// Priority-queued batch executor with fingerprint-based grouping, result
/// caching, and pluggable execution strategies (`BatchMode`).
pub struct QueryBatchExecutor {
    /// Tuning parameters: concurrency, mode, timeouts, feature flags.
    config: BatchConfig,
    /// Pending queries, kept ordered by descending priority by `add_query`.
    queue: Arc<Mutex<VecDeque<BatchQuery>>>,
    /// Per-query results by id. NOTE(review): never written or read in this
    /// file — confirm whether it is still needed.
    #[allow(dead_code)]
    results: Arc<RwLock<HashMap<String, BatchQueryResult>>>,
    /// Produces structural fingerprints used for grouping and cache keys.
    fingerprinter: QueryFingerprinter,
    /// Result cache keyed by fingerprint hash.
    cache: Arc<RwLock<HashMap<String, String>>>,
    /// Aggregate execution statistics shared with the worker paths.
    stats: Arc<RwLock<BatchStatistics>>,
    /// Metric: successfully executed queries.
    queries_executed: Counter,
    /// Metric: failed (or panicked/cancelled) queries.
    queries_failed: Counter,
    /// Metric: wall-clock duration of each batch.
    batch_duration: Timer,
    /// Metric: queries currently in flight.
    active_queries: Gauge,
}
impl QueryBatchExecutor {
/// Creates an executor with the given configuration, an empty queue and
/// cache, a default-configured fingerprinter, and freshly named metrics.
pub fn new(config: BatchConfig) -> Self {
    Self {
        config,
        queue: Arc::new(Mutex::new(VecDeque::new())),
        results: Arc::new(RwLock::new(HashMap::new())),
        fingerprinter: QueryFingerprinter::new(FingerprintConfig::default()),
        cache: Arc::new(RwLock::new(HashMap::new())),
        stats: Arc::new(RwLock::new(BatchStatistics::new())),
        queries_executed: Counter::new("batch_queries_executed".to_string()),
        queries_failed: Counter::new("batch_queries_failed".to_string()),
        batch_duration: Timer::new("batch_execution_duration".to_string()),
        active_queries: Gauge::new("batch_active_queries".to_string()),
    }
}
/// Enqueues a query and returns its generated id (`query_<uuid>`).
///
/// When grouping is enabled, the query is fingerprinted up front and any
/// fingerprinting error is propagated. Insertion keeps the queue ordered
/// by descending priority, FIFO within a priority level.
///
/// NOTE(review): for a queue that is only ever populated by this method
/// (and therefore stays sorted), the `fair_scheduling` branch and the
/// non-fair branch compute the same insertion position — confirm whether
/// a distinct non-fair policy (e.g. LIFO within a priority) was intended.
pub fn add_query(&self, query: impl Into<String>, priority: QueryPriority) -> Result<String> {
    let query = query.into();
    let id = format!("query_{}", uuid::Uuid::new_v4());
    // Fingerprint eagerly so Optimized mode can group by hash later.
    let fingerprint = if self.config.enable_grouping {
        Some(self.fingerprinter.fingerprint(&query)?)
    } else {
        None
    };
    let batch_query = BatchQuery {
        id: id.clone(),
        query,
        priority,
        fingerprint,
        submitted_at: Instant::now(),
        started_at: None,
        completed_at: None,
    };
    let mut queue = self.queue.lock().expect("lock poisoned");
    if self.config.fair_scheduling {
        // Insert after the last entry of equal-or-higher priority.
        let insert_pos = queue
            .iter()
            .rposition(|q| q.priority >= priority)
            .map(|pos| pos + 1)
            .unwrap_or(0);
        queue.insert(insert_pos, batch_query);
    } else {
        // Insert before the first entry of strictly lower priority.
        let insert_pos = queue
            .iter()
            .position(|q| q.priority < priority)
            .unwrap_or(queue.len());
        queue.insert(insert_pos, batch_query);
    }
    Ok(id)
}
/// Enqueues several queries at once, returning their ids in input order.
/// Stops and returns the error of the first query that fails to enqueue.
pub fn add_queries(&self, queries: Vec<(String, QueryPriority)>) -> Result<Vec<String>> {
    let mut ids = Vec::with_capacity(queries.len());
    for (query, priority) in queries {
        ids.push(self.add_query(query, priority)?);
    }
    Ok(ids)
}
/// Number of queries currently waiting in the queue.
pub fn queue_size(&self) -> usize {
    self.queue.lock().expect("lock poisoned").len()
}
/// Discards all pending (not yet executed) queries.
pub fn clear_queue(&self) {
    self.queue.lock().expect("lock poisoned").clear();
}
/// Snapshot of the aggregate statistics (a clone of the shared record).
pub fn statistics(&self) -> BatchStatistics {
    self.stats.read().expect("lock poisoned").clone()
}
/// Drains the queue and executes every pending query against `dataset`,
/// dispatching to the strategy selected by `config.mode`.
///
/// Returns the per-query results (empty `Vec` if nothing was queued).
/// Records `total_queries`, `total_duration`, derived statistics, and the
/// batch-duration metric.
///
/// NOTE(review): `total_queries` is overwritten each batch while
/// `successful_queries`/`failed_queries` accumulate across batches in
/// `update_stats`, so success/cache rates can exceed 1.0 after a second
/// batch — confirm whether statistics are meant to be per-batch or
/// cumulative.
pub async fn execute_batch_async<D: Dataset + Send + Sync + 'static>(
    &self,
    dataset: Arc<D>,
) -> Result<Vec<BatchQueryResult>> {
    let start_time = Instant::now();
    // Take ownership of the whole queue up front; queries submitted during
    // execution wait for the next batch.
    let queries: Vec<BatchQuery> = {
        let mut queue = self.queue.lock().expect("lock poisoned");
        queue.drain(..).collect()
    };
    if queries.is_empty() {
        return Ok(Vec::new());
    }
    {
        let mut stats = self.stats.write().expect("lock poisoned");
        stats.total_queries = queries.len();
    }
    let results = match self.config.mode {
        BatchMode::Parallel => self.execute_parallel(queries, dataset).await?,
        BatchMode::Sequential => self.execute_sequential(queries, dataset).await?,
        BatchMode::Optimized => self.execute_optimized(queries, dataset).await?,
        BatchMode::Adaptive => self.execute_adaptive(queries, dataset).await?,
    };
    let duration = start_time.elapsed();
    {
        let mut stats = self.stats.write().expect("lock poisoned");
        stats.total_duration = duration;
        stats.calculate_derived();
    }
    self.batch_duration.observe(duration);
    Ok(results)
}
/// Runs all queries concurrently, bounded by `max_concurrent` permits.
///
/// Submission acquires an owned semaphore permit before spawning, so at
/// most `max_concurrent` tasks exist at once; each task releases its permit
/// as soon as its query finishes. Results are collected in submission
/// order; a task that panics or is cancelled is logged and counted as a
/// failure rather than producing a result.
async fn execute_parallel<D: Dataset + Send + Sync + 'static>(
    &self,
    queries: Vec<BatchQuery>,
    dataset: Arc<D>,
) -> Result<Vec<BatchQueryResult>> {
    let semaphore = Arc::new(tokio::sync::Semaphore::new(self.config.max_concurrent));
    let mut handles: Vec<JoinHandle<BatchQueryResult>> = Vec::new();
    for query in queries {
        // Back-pressure: wait for a free slot before spawning the task.
        let permit = semaphore.clone().acquire_owned().await?;
        let dataset = dataset.clone();
        let timeout = self.config.query_timeout;
        let cache = self.cache.clone();
        let enable_caching = self.config.enable_caching;
        let fingerprint = query.fingerprint.clone();
        self.active_queries.inc();
        let handle = tokio::spawn(async move {
            let result = Self::execute_single_query(
                query,
                dataset,
                timeout,
                cache,
                enable_caching,
                fingerprint,
            )
            .await;
            drop(permit); // free the slot as soon as the query completes
            result
        });
        handles.push(handle);
    }
    let mut results = Vec::new();
    for handle in handles {
        match handle.await {
            Ok(result) => {
                self.update_stats(&result);
                results.push(result);
            }
            Err(e) => {
                eprintln!("Task failed: {}", e);
                self.queries_failed.inc();
                // Fix: previously only the metric counter was bumped, so
                // BatchStatistics::failed_queries stayed out of sync with
                // total_queries when a task panicked or was cancelled.
                self.stats.write().expect("lock poisoned").failed_queries += 1;
            }
        }
        self.active_queries.dec();
    }
    Ok(results)
}
/// Executes the batch one query at a time, in the order given, updating
/// statistics and the in-flight gauge per query.
async fn execute_sequential<D: Dataset + Send + Sync + 'static>(
    &self,
    queries: Vec<BatchQuery>,
    dataset: Arc<D>,
) -> Result<Vec<BatchQueryResult>> {
    let mut outcomes = Vec::with_capacity(queries.len());
    for query in queries {
        self.active_queries.inc();
        let outcome = Self::execute_single_query(
            query,
            Arc::clone(&dataset),
            self.config.query_timeout,
            Arc::clone(&self.cache),
            self.config.enable_caching,
            None,
        )
        .await;
        self.update_stats(&outcome);
        outcomes.push(outcome);
        self.active_queries.dec();
    }
    Ok(outcomes)
}
/// `Optimized` mode: buckets queries by fingerprint hash (falling back to
/// the query id when no fingerprint exists, which yields singleton
/// groups), records the group count, then runs each group through
/// `execute_parallel`, one group after another.
///
/// NOTE(review): `HashMap` iteration order is unspecified, so the order of
/// the returned results is not the submission order — confirm callers
/// match results by id rather than position.
async fn execute_optimized<D: Dataset + Send + Sync + 'static>(
    &self,
    queries: Vec<BatchQuery>,
    dataset: Arc<D>,
) -> Result<Vec<BatchQueryResult>> {
    let mut groups: HashMap<String, Vec<BatchQuery>> = HashMap::new();
    for query in queries {
        let key = query
            .fingerprint
            .as_ref()
            .map(|f| f.hash.clone())
            .unwrap_or_else(|| query.id.clone());
        groups.entry(key).or_default().push(query);
    }
    {
        let mut stats = self.stats.write().expect("lock poisoned");
        stats.query_groups = groups.len();
    }
    let mut all_results = Vec::new();
    for (_key, group) in groups {
        let group_results = self.execute_parallel(group, dataset.clone()).await?;
        all_results.extend(group_results);
    }
    Ok(all_results)
}
/// Executes the batch with a concurrency limit tuned at runtime by an
/// `AdaptiveConcurrencyController`. A background task polls the controller
/// every five seconds and adds semaphore permits when the target rises;
/// each query task acquires a permit before running.
///
/// NOTE(review): the adjustment task compares the new target against
/// `available_permits()`, which excludes permits currently held by running
/// queries, so it can over-grant permits; Tokio semaphores also cannot
/// revoke permits, so the limit never shrinks. Confirm this drift is
/// acceptable or track the total issued permits separately.
async fn execute_adaptive<D: Dataset + Send + Sync + 'static>(
    &self,
    queries: Vec<BatchQuery>,
    dataset: Arc<D>,
) -> Result<Vec<BatchQueryResult>> {
    use tokio::sync::Semaphore;
    let controller = Arc::new(
        AdaptiveConcurrencyController::new(self.config.max_concurrent)
            .with_thresholds(0.75, 0.40)
            .with_adjustment_interval(Duration::from_secs(5)),
    );
    let initial_permits = controller.current_concurrency();
    let semaphore = Arc::new(Semaphore::new(initial_permits));
    let controller_clone = Arc::clone(&controller);
    let semaphore_clone = Arc::clone(&semaphore);
    // Periodically widen the semaphore to follow the controller's target.
    let adjustment_task = tokio::spawn(async move {
        loop {
            tokio::time::sleep(Duration::from_secs(5)).await;
            controller_clone.update_concurrency();
            let new_concurrency = controller_clone.current_concurrency();
            let current_permits = semaphore_clone.available_permits();
            if new_concurrency > current_permits {
                semaphore_clone.add_permits(new_concurrency - current_permits);
            }
        }
    });
    let mut tasks = Vec::new();
    for query in queries {
        let dataset_clone = Arc::clone(&dataset);
        let cache_clone = Arc::clone(&self.cache);
        let timeout = self.config.query_timeout;
        let enable_caching = self.config.enable_caching;
        let fingerprint = query.fingerprint.clone();
        let semaphore_clone = Arc::clone(&semaphore);
        let task = tokio::spawn(async move {
            // Acquire inside the task so queued tasks wait for a permit.
            let _permit = semaphore_clone.acquire().await.expect("Semaphore closed");
            Self::execute_single_query(
                query,
                dataset_clone,
                timeout,
                cache_clone,
                enable_caching,
                fingerprint,
            )
            .await
        });
        tasks.push(task);
    }
    let mut results = Vec::with_capacity(tasks.len());
    for task in tasks {
        match task.await {
            Ok(result) => {
                // Fix: adaptive mode previously returned results without
                // calling update_stats, so successful/failed counts and
                // min/max durations were never recorded for this mode,
                // unlike the other three execution strategies.
                self.update_stats(&result);
                results.push(result);
            }
            Err(e) => {
                eprintln!("Task execution error: {}", e);
                self.queries_failed.inc();
                self.stats.write().expect("lock poisoned").failed_queries += 1;
            }
        }
    }
    adjustment_task.abort();
    Ok(results)
}
/// Executes one query, honoring the per-query timeout and the optional
/// fingerprint-keyed result cache.
///
/// Cache hits return immediately with the cached payload. Otherwise the
/// query runs under `tokio::time::timeout`; on success the result is
/// inserted into the cache (when enabled and fingerprinted), on timeout a
/// failed result with error "Query timeout" is returned.
///
/// NOTE(review): the execution itself is placeholder scaffolding — it
/// sleeps 10 ms and formats a string instead of evaluating against
/// `_dataset`; `_fingerprint` is also unused (the fingerprint is read from
/// `query.fingerprint`). Confirm this is intentional.
async fn execute_single_query<D: Dataset + Send + Sync + 'static>(
    mut query: BatchQuery,
    _dataset: Arc<D>,
    timeout: Duration,
    cache: Arc<RwLock<HashMap<String, String>>>,
    enable_caching: bool,
    _fingerprint: Option<QueryFingerprint>,
) -> BatchQueryResult {
    query.started_at = Some(Instant::now());
    let start = Instant::now();
    // Fast path: serve a cached result keyed by the fingerprint hash.
    if enable_caching {
        if let Some(fp) = &query.fingerprint {
            if let Some(cached) = cache.read().expect("lock poisoned").get(&fp.hash) {
                query.completed_at = Some(Instant::now());
                return BatchQueryResult {
                    id: query.id,
                    success: true,
                    result_count: cached.lines().count(),
                    results: Some(cached.clone()),
                    error: None,
                    duration: start.elapsed(),
                };
            }
        }
    }
    let result = tokio::time::timeout(timeout, async {
        tokio::time::sleep(Duration::from_millis(10)).await;
        Ok::<String, anyhow::Error>(format!("Results for: {}", query.query))
    })
    .await;
    query.completed_at = Some(Instant::now());
    let duration = start.elapsed();
    match result {
        Ok(Ok(results)) => {
            if enable_caching {
                if let Some(fp) = &query.fingerprint {
                    cache
                        .write()
                        .expect("lock poisoned")
                        .insert(fp.hash.clone(), results.clone());
                }
            }
            // Fix: count lines before moving `results` into the response;
            // the previous version cloned the String a second time just to
            // keep it alive for the count.
            let result_count = results.lines().count();
            BatchQueryResult {
                id: query.id,
                success: true,
                results: Some(results),
                error: None,
                duration,
                result_count,
            }
        }
        Ok(Err(e)) => BatchQueryResult {
            id: query.id,
            success: false,
            results: None,
            error: Some(e.to_string()),
            duration,
            result_count: 0,
        },
        Err(_) => BatchQueryResult {
            id: query.id,
            success: false,
            results: None,
            error: Some("Query timeout".to_string()),
            duration,
            result_count: 0,
        },
    }
}
/// Folds one query outcome into the shared statistics and the metric
/// counters. Min/max durations and result totals are tracked for
/// successful queries only.
fn update_stats(&self, result: &BatchQueryResult) {
    let mut stats = self.stats.write().expect("lock poisoned");
    if !result.success {
        stats.failed_queries += 1;
        self.queries_failed.inc();
        return;
    }
    stats.successful_queries += 1;
    self.queries_executed.inc();
    stats.total_results += result.result_count;
    stats.min_duration = stats.min_duration.min(result.duration);
    stats.max_duration = stats.max_duration.max(result.duration);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builder methods should store exactly the values given.
    #[test]
    fn test_batch_config_builder() {
        let config = BatchConfig::default()
            .with_max_concurrent(32)
            .with_memory_limit_mb(8192)
            .with_cpu_limit(0.9)
            .with_mode(BatchMode::Parallel);
        assert_eq!(config.max_concurrent, 32);
        assert_eq!(config.memory_limit_mb, 8192);
        assert_eq!(config.cpu_limit, 0.9);
        assert_eq!(config.mode, BatchMode::Parallel);
    }

    // Derived Ord must follow the explicit discriminants (High=3 … Background=0).
    #[test]
    fn test_priority_ordering() {
        assert!(QueryPriority::High > QueryPriority::Normal);
        assert!(QueryPriority::Normal > QueryPriority::Low);
        assert!(QueryPriority::Low > QueryPriority::Background);
    }

    // Derived rates and throughput from raw counters.
    #[test]
    fn test_batch_statistics() {
        let mut stats = BatchStatistics::new();
        stats.total_queries = 100;
        stats.successful_queries = 95;
        stats.failed_queries = 5;
        stats.total_duration = Duration::from_secs(10);
        stats.cached_queries = 20;
        stats.calculate_derived();
        assert_eq!(stats.success_rate(), 0.95);
        assert_eq!(stats.cache_hit_rate(), 0.2);
        assert_eq!(stats.throughput, 10.0);
    }

    // High-priority queries must be inserted ahead of normal ones.
    #[test]
    fn test_add_query() {
        let executor = QueryBatchExecutor::new(BatchConfig::default());
        let id1 = executor
            .add_query("SELECT * WHERE { ?s ?p ?o }", QueryPriority::Normal)
            .unwrap();
        let id2 = executor
            .add_query("ASK { ?s a :Person }", QueryPriority::High)
            .unwrap();
        assert_eq!(executor.queue_size(), 2);
        assert_ne!(id1, id2);
        let queue = executor.queue.lock().expect("lock should not be poisoned");
        assert_eq!(queue[0].priority, QueryPriority::High);
        assert_eq!(queue[1].priority, QueryPriority::Normal);
    }

    // Bulk submission returns one id per query and fills the queue.
    #[test]
    fn test_add_multiple_queries() {
        let executor = QueryBatchExecutor::new(BatchConfig::default());
        let queries = vec![
            (
                "SELECT ?s WHERE { ?s ?p ?o }".to_string(),
                QueryPriority::Normal,
            ),
            (
                "SELECT ?p WHERE { ?s ?p ?o }".to_string(),
                QueryPriority::Low,
            ),
            (
                "SELECT ?o WHERE { ?s ?p ?o }".to_string(),
                QueryPriority::High,
            ),
        ];
        let ids = executor.add_queries(queries).unwrap();
        assert_eq!(ids.len(), 3);
        assert_eq!(executor.queue_size(), 3);
    }

    // clear_queue drops all pending queries.
    #[test]
    fn test_clear_queue() {
        let executor = QueryBatchExecutor::new(BatchConfig::default());
        executor
            .add_query("SELECT * WHERE { ?s ?p ?o }", QueryPriority::Normal)
            .unwrap();
        executor
            .add_query("ASK { ?s a :Person }", QueryPriority::High)
            .unwrap();
        assert_eq!(executor.queue_size(), 2);
        executor.clear_queue();
        assert_eq!(executor.queue_size(), 0);
    }

    // with_mode must round-trip every BatchMode variant.
    #[test]
    fn test_batch_modes() {
        let modes = vec![
            BatchMode::Parallel,
            BatchMode::Sequential,
            BatchMode::Optimized,
            BatchMode::Adaptive,
        ];
        for mode in modes {
            let config = BatchConfig::default().with_mode(mode);
            assert_eq!(config.mode, mode);
        }
    }

    // Fair scheduling: descending priority, FIFO within a priority level.
    #[test]
    fn test_fair_scheduling() {
        let config = BatchConfig {
            fair_scheduling: true,
            ..Default::default()
        };
        let executor = QueryBatchExecutor::new(config);
        executor.add_query("Q1", QueryPriority::Normal).unwrap();
        executor.add_query("Q2", QueryPriority::High).unwrap();
        executor.add_query("Q3", QueryPriority::Normal).unwrap();
        executor.add_query("Q4", QueryPriority::High).unwrap();
        let queue = executor.queue.lock().expect("lock should not be poisoned");
        assert_eq!(queue[0].priority, QueryPriority::High);
        assert_eq!(queue[1].priority, QueryPriority::High);
        assert_eq!(queue[2].priority, QueryPriority::Normal);
        assert_eq!(queue[3].priority, QueryPriority::Normal);
    }

    // Builders clamp out-of-range values (min concurrency 1, cpu cap 1.0).
    #[test]
    fn test_config_limits() {
        let config = BatchConfig::default()
            .with_max_concurrent(0)
            .with_cpu_limit(1.5);
        assert_eq!(config.max_concurrent, 1);
        assert_eq!(config.cpu_limit, 1.0);
    }

    // Lifecycle timestamps start unset.
    #[test]
    fn test_batch_query_timing() {
        let query = BatchQuery {
            id: "test".to_string(),
            query: "SELECT * WHERE { ?s ?p ?o }".to_string(),
            priority: QueryPriority::Normal,
            fingerprint: None,
            submitted_at: Instant::now(),
            started_at: None,
            completed_at: None,
        };
        assert!(query.started_at.is_none());
        assert!(query.completed_at.is_none());
    }
}