use anyhow::{anyhow, Result};
use scirs2_core::metrics::{Counter, Gauge, Histogram};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{RwLock, Semaphore};
use tokio::task::JoinSet;
/// Tuning knobs for parallel field resolution.
#[derive(Debug, Clone)]
pub struct ParallelResolutionConfig {
    /// Master switch: when false, every request uses the sequential path.
    pub enabled: bool,
    /// Upper bound on concurrently running field resolvers (semaphore permits).
    pub max_concurrency: usize,
    /// Minimum number of fields in a request before the parallel path is used.
    pub min_fields_for_parallel: usize,
    /// NOTE(review): declared but not read by the visible code — presumably
    /// reserved for dynamic tuning of `max_concurrency`; confirm before relying on it.
    pub adaptive_concurrency: bool,
    /// NOTE(review): declared but not read by the visible code.
    pub enable_work_stealing: bool,
    /// NOTE(review): declared but not read by the visible code.
    pub max_queue_depth: usize,
    /// Wall-clock timeout applied to each individual field resolver.
    pub field_timeout: Duration,
}
impl Default for ParallelResolutionConfig {
    /// Defaults: parallelism enabled, concurrency at twice the logical CPU
    /// count, parallel path from 3 fields up, 30-second per-field timeout.
    fn default() -> Self {
        let default_concurrency = num_cpus::get() * 2;
        Self {
            enabled: true,
            max_concurrency: default_concurrency,
            min_fields_for_parallel: 3,
            adaptive_concurrency: true,
            enable_work_stealing: true,
            max_queue_depth: 1000,
            field_timeout: Duration::from_secs(30),
        }
    }
}
/// Uniquely identifies a field occurrence within a query, including its
/// position in the response tree.
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
pub struct FieldId {
    /// Name of the parent type the field is selected on (e.g. "User").
    pub parent_type: String,
    /// Schema-level field name.
    pub field_name: String,
    /// Optional response alias; when set it overrides `field_name` in output.
    pub alias: Option<String>,
    /// Response keys from the query root down to this field.
    pub path: Vec<String>,
}
impl FieldId {
    /// Creates an identifier with no alias and an empty path.
    pub fn new(parent_type: String, field_name: String) -> Self {
        Self {
            parent_type,
            field_name,
            alias: None,
            path: Vec::new(),
        }
    }

    /// Builder-style setter for the response alias.
    pub fn with_alias(mut self, alias: String) -> Self {
        self.alias = Some(alias);
        self
    }

    /// Builder-style setter for the response path.
    pub fn with_path(mut self, path: Vec<String>) -> Self {
        self.path = path;
        self
    }

    /// The key this field appears under in the response: the alias when one
    /// is set, otherwise the schema field name.
    pub fn effective_name(&self) -> &str {
        match &self.alias {
            Some(alias) => alias.as_str(),
            None => &self.field_name,
        }
    }
}
/// How one field's resolution relates to other fields in the same request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FieldDependency {
    /// Must wait for another field to finish before this one can start.
    DataDependency {
        /// The field that has to complete first.
        depends_on: FieldId,
        /// Human-readable explanation, for diagnostics.
        reason: String,
    },
    /// Requires named context entries; note this does NOT gate scheduling in
    /// the visible scheduler (`DependencyGraph::can_resolve` ignores it).
    ContextDependency {
        required_context: Vec<String>,
    },
    /// No ordering constraints.
    Independent,
}
/// One unit of resolution work: a field plus the closure that produces its value.
pub struct FieldResolutionTask<T> {
    /// The field this task resolves.
    pub field_id: FieldId,
    /// Synchronous resolver closure; executed on the blocking thread pool.
    pub resolver: Arc<dyn Fn() -> Result<T> + Send + Sync>,
    /// Declared dependencies (mirrored in the `DependencyGraph` used for scheduling).
    pub dependencies: Vec<FieldDependency>,
    /// Relative cost estimate. NOTE(review): not read by the visible scheduler.
    pub estimated_cost: f64,
    /// Scheduling priority. NOTE(review): not read by the visible scheduler.
    pub priority: i32,
}
/// Outcome of resolving a single field.
#[derive(Debug, Clone)]
pub struct FieldResolutionResult<T> {
    /// The field that was resolved.
    pub field_id: FieldId,
    /// The value, or a stringified failure (resolver error, panic, or timeout).
    pub result: Result<T, String>,
    /// Wall-clock duration of the resolver call.
    pub execution_time: Duration,
    /// When the result was produced; used by tests to verify ordering.
    #[allow(dead_code)]
    pub resolved_at: Instant,
}
/// Tracks inter-field dependencies plus which fields are already resolved or
/// currently running. The dependency map is fixed at construction; the two
/// sets are mutated concurrently behind async RwLocks.
#[derive(Debug)]
pub struct DependencyGraph {
    /// Static dependency lists, keyed by field.
    dependencies: HashMap<FieldId, Vec<FieldDependency>>,
    /// Fields whose resolution has completed.
    resolved: RwLock<HashSet<FieldId>>,
    /// Fields currently being resolved (prevents double scheduling).
    in_progress: RwLock<HashSet<FieldId>>,
}
impl DependencyGraph {
    /// Creates an empty graph: no known dependencies, nothing resolved,
    /// nothing in progress.
    pub fn new() -> Self {
        Self {
            dependencies: HashMap::new(),
            resolved: RwLock::new(HashSet::new()),
            in_progress: RwLock::new(HashSet::new()),
        }
    }

    /// Builder-style setter replacing the dependency map wholesale.
    pub fn with_dependencies(mut self, deps: HashMap<FieldId, Vec<FieldDependency>>) -> Self {
        self.dependencies = deps;
        self
    }

    /// Whether `field_id` is schedulable right now: neither resolved nor in
    /// progress, and every data dependency already resolved. Context
    /// dependencies never gate scheduling here (context is provided externally).
    pub async fn can_resolve(&self, field_id: &FieldId) -> bool {
        let resolved = self.resolved.read().await;
        let in_progress = self.in_progress.read().await;
        if resolved.contains(field_id) || in_progress.contains(field_id) {
            return false;
        }
        if let Some(deps) = self.dependencies.get(field_id) {
            for dep in deps {
                match dep {
                    FieldDependency::DataDependency { depends_on, .. } => {
                        if !resolved.contains(depends_on) {
                            return false;
                        }
                    }
                    FieldDependency::ContextDependency { .. } => {}
                    FieldDependency::Independent => {}
                }
            }
        }
        true
    }

    /// Marks a field as currently being resolved so it is not scheduled twice.
    pub async fn mark_in_progress(&self, field_id: FieldId) {
        let mut in_progress = self.in_progress.write().await;
        in_progress.insert(field_id);
    }

    /// Moves a field from in-progress to resolved, unblocking its dependents.
    pub async fn mark_resolved(&self, field_id: FieldId) {
        let mut resolved = self.resolved.write().await;
        let mut in_progress = self.in_progress.write().await;
        in_progress.remove(&field_id);
        resolved.insert(field_id);
    }

    /// All fields that are currently schedulable.
    ///
    /// Takes both lock guards once for the whole scan (rather than
    /// re-acquiring them per field through `can_resolve`), so the returned
    /// list is a consistent snapshot and avoids O(fields) lock round-trips.
    pub async fn get_ready_fields(&self) -> Vec<FieldId> {
        let resolved = self.resolved.read().await;
        let in_progress = self.in_progress.read().await;
        let mut ready = Vec::new();
        for (field_id, deps) in &self.dependencies {
            if resolved.contains(field_id) || in_progress.contains(field_id) {
                continue;
            }
            let blocked = deps.iter().any(|dep| match dep {
                FieldDependency::DataDependency { depends_on, .. } => {
                    !resolved.contains(depends_on)
                }
                FieldDependency::ContextDependency { .. } | FieldDependency::Independent => false,
            });
            if !blocked {
                ready.push(field_id.clone());
            }
        }
        ready
    }

    /// Derives each field's dependency list from its metadata.
    ///
    /// An argument source of the form `"$other"` becomes a data dependency on
    /// the sibling field named `other`. A field is never made to depend on
    /// itself even when its own name matches the source: a self-dependency
    /// could never be satisfied (`can_resolve` would return false forever)
    /// and would deadlock resolution. Non-empty `required_context` adds a
    /// context dependency; fields with no derived dependencies — or no
    /// metadata at all — are marked `Independent`.
    pub fn analyze_dependencies(
        fields: &[FieldId],
        field_metadata: &HashMap<FieldId, FieldMetadata>,
    ) -> HashMap<FieldId, Vec<FieldDependency>> {
        let mut dependencies = HashMap::new();
        for field in fields {
            let mut field_deps = Vec::new();
            if let Some(metadata) = field_metadata.get(field) {
                for arg_source in &metadata.argument_sources {
                    if let Some(source_field) = arg_source.strip_prefix('$') {
                        // Skip `field` itself so "$x" on field `x` cannot
                        // introduce an unsatisfiable self-dependency.
                        let dep_field = fields
                            .iter()
                            .find(|f| *f != field && f.field_name == source_field)
                            .cloned();
                        if let Some(depends_on) = dep_field {
                            field_deps.push(FieldDependency::DataDependency {
                                depends_on,
                                reason: format!("Argument dependency on field '{source_field}'"),
                            });
                        }
                    }
                }
                if !metadata.required_context.is_empty() {
                    field_deps.push(FieldDependency::ContextDependency {
                        required_context: metadata.required_context.clone(),
                    });
                }
                if field_deps.is_empty() {
                    field_deps.push(FieldDependency::Independent);
                }
            } else {
                // No metadata: assume the field can resolve on its own.
                field_deps.push(FieldDependency::Independent);
            }
            dependencies.insert(field.clone(), field_deps);
        }
        dependencies
    }
}
impl Default for DependencyGraph {
fn default() -> Self {
Self::new()
}
}
/// Per-field planner metadata used to derive scheduling dependencies.
#[derive(Debug, Clone)]
pub struct FieldMetadata {
    /// Argument value sources; entries of the form "$name" reference sibling fields.
    pub argument_sources: Vec<String>,
    /// Context keys the field needs at resolution time.
    pub required_context: Vec<String>,
    /// Relative cost estimate. NOTE(review): not read by the visible code.
    pub estimated_cost: f64,
    /// Whether the result may be cached. NOTE(review): not read by the visible code.
    pub cacheable: bool,
}
/// Executes field resolution tasks, running independent fields concurrently
/// when the batch is large enough.
pub struct ParallelFieldResolver {
    config: ParallelResolutionConfig,
    /// Caps the number of resolver closures running at once.
    concurrency_semaphore: Arc<Semaphore>,
    /// Gauge: size of the batch currently being resolved.
    active_resolutions: Arc<Gauge>,
    /// Counter: total fields ever submitted for resolution.
    total_resolutions: Arc<Counter>,
    /// Histogram of per-field resolution latency, in milliseconds.
    resolution_time: Arc<Histogram>,
    /// Gauge: last batch's speedup estimate (summed field time / wall time).
    parallelization_rate: Arc<Gauge>,
}
impl ParallelFieldResolver {
pub fn new(config: ParallelResolutionConfig) -> Self {
Self {
concurrency_semaphore: Arc::new(Semaphore::new(config.max_concurrency)),
active_resolutions: Arc::new(Gauge::new(
"parallel_resolver_active_resolutions".to_string(),
)),
total_resolutions: Arc::new(Counter::new(
"parallel_resolver_total_resolutions".to_string(),
)),
resolution_time: Arc::new(Histogram::new(
"parallel_resolver_resolution_time_ms".to_string(),
)),
parallelization_rate: Arc::new(Gauge::new(
"parallel_resolver_parallelization_rate".to_string(),
)),
config,
}
}
pub async fn resolve_fields<T: Clone + Send + Sync + 'static>(
&self,
tasks: Vec<FieldResolutionTask<T>>,
dependency_graph: Arc<DependencyGraph>,
) -> Result<Vec<FieldResolutionResult<T>>, anyhow::Error> {
if !self.config.enabled || tasks.len() < self.config.min_fields_for_parallel {
return self.resolve_sequential(tasks).await;
}
let start_time = Instant::now();
let total_fields = tasks.len();
self.active_resolutions.set(total_fields as f64);
for _ in 0..total_fields {
self.total_resolutions.inc();
}
let task_map: HashMap<FieldId, FieldResolutionTask<T>> = tasks
.into_iter()
.map(|task| (task.field_id.clone(), task))
.collect();
let mut results: Vec<FieldResolutionResult<T>> = Vec::new();
let mut join_set: JoinSet<Result<FieldResolutionResult<T>, anyhow::Error>> = JoinSet::new();
let mut completed_fields = HashSet::new();
while completed_fields.len() < total_fields {
let ready_fields = dependency_graph.get_ready_fields().await;
if ready_fields.is_empty() {
if let Some(result) = join_set.join_next().await {
let field_result =
result.map_err(|e| anyhow::anyhow!("Task join error: {e}"))??;
dependency_graph
.mark_resolved(field_result.field_id.clone())
.await;
completed_fields.insert(field_result.field_id.clone());
results.push(field_result);
} else if completed_fields.len() < total_fields {
return Err(anyhow!(
"Dependency deadlock detected: {}/{} fields completed",
completed_fields.len(),
total_fields
));
}
continue;
}
for field_id in ready_fields {
if completed_fields.contains(&field_id) {
continue;
}
if let Some(task) = task_map.get(&field_id) {
dependency_graph.mark_in_progress(field_id.clone()).await;
let field_id = task.field_id.clone();
let resolver = Arc::clone(&task.resolver);
let semaphore = Arc::clone(&self.concurrency_semaphore);
let timeout = self.config.field_timeout;
join_set.spawn(async move {
let _permit = semaphore
.acquire()
.await
.map_err(|e| anyhow::anyhow!("Semaphore error: {e}"))?;
let start = Instant::now();
let result = tokio::time::timeout(
timeout,
tokio::task::spawn_blocking(move || resolver()),
)
.await;
let execution_time = start.elapsed();
let result = match result {
Ok(Ok(Ok(value))) => Ok(value),
Ok(Ok(Err(e))) => Err(format!("Resolver error: {e}")),
Ok(Err(e)) => Err(format!("Task panic: {e}")),
Err(_) => Err(format!("Resolution timeout after {timeout:?}")),
};
Ok::<FieldResolutionResult<T>, anyhow::Error>(FieldResolutionResult {
field_id,
result,
execution_time,
resolved_at: Instant::now(),
})
});
}
}
while let Some(result) = join_set.try_join_next() {
let field_result =
result.map_err(|e| anyhow::anyhow!("Task join error: {e}"))??;
dependency_graph
.mark_resolved(field_result.field_id.clone())
.await;
completed_fields.insert(field_result.field_id.clone());
results.push(field_result);
}
}
while let Some(result) = join_set.join_next().await {
let field_result = result.map_err(|e| anyhow::anyhow!("Task join error: {e}"))??;
results.push(field_result);
}
let total_time = start_time.elapsed();
let sequential_time: Duration = results.iter().map(|r| r.execution_time).sum();
let parallelization = if total_time.as_secs_f64() > 0.0 {
sequential_time.as_secs_f64() / total_time.as_secs_f64()
} else {
1.0
};
self.parallelization_rate.set(parallelization);
self.active_resolutions.set(0.0);
Ok(results)
}
async fn resolve_sequential<T: Clone + Send + Sync + 'static>(
&self,
tasks: Vec<FieldResolutionTask<T>>,
) -> Result<Vec<FieldResolutionResult<T>>> {
let mut results = Vec::new();
for task in tasks {
let start = Instant::now();
let resolver = Arc::clone(&task.resolver);
let result = tokio::task::spawn_blocking(move || resolver())
.await
.map_err(|e| anyhow!("Task error: {e}"))?;
let execution_time = start.elapsed();
self.resolution_time
.observe(execution_time.as_millis() as f64);
results.push(FieldResolutionResult {
field_id: task.field_id,
result: result.map_err(|e| e.to_string()),
execution_time,
resolved_at: Instant::now(),
});
}
self.parallelization_rate.set(1.0); Ok(results)
}
pub fn get_metrics(&self) -> ParallelResolverMetrics {
ParallelResolverMetrics {
active_resolutions: self.active_resolutions.get() as usize,
total_resolutions: self.total_resolutions.get() as usize,
avg_resolution_time_ms: 0.0, parallelization_rate: self.parallelization_rate.get(),
}
}
}
/// Serializable snapshot of `ParallelFieldResolver` metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelResolverMetrics {
    /// Fields in the batch currently being resolved.
    pub active_resolutions: usize,
    /// Total fields submitted since the resolver was created.
    pub total_resolutions: usize,
    /// NOTE(review): currently always 0.0 — `get_metrics` hardcodes it.
    pub avg_resolution_time_ms: f64,
    /// Last batch's speedup estimate (summed field time / wall time).
    pub parallelization_rate: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores the names verbatim; without an alias, `effective_name`
    /// falls back to the field name.
    #[test]
    fn test_field_id_creation() {
        let field = FieldId::new("User".to_string(), "name".to_string());
        assert_eq!(field.parent_type, "User");
        assert_eq!(field.field_name, "name");
        assert_eq!(field.effective_name(), "name");
    }

    /// An alias overrides the field name in `effective_name`.
    #[test]
    fn test_field_id_with_alias() {
        let field = FieldId::new("User".to_string(), "firstName".to_string())
            .with_alias("name".to_string());
        assert_eq!(field.effective_name(), "name");
    }

    /// `with_path` replaces the path wholesale.
    #[test]
    fn test_field_id_with_path() {
        let field = FieldId::new("User".to_string(), "name".to_string())
            .with_path(vec!["user".to_string(), "profile".to_string()]);
        assert_eq!(field.path.len(), 2);
    }

    /// A fresh graph has no registered dependencies.
    #[tokio::test]
    async fn test_dependency_graph_creation() {
        let graph = DependencyGraph::new();
        assert!(graph.dependencies.is_empty());
    }

    /// Independent fields are immediately schedulable.
    #[tokio::test]
    async fn test_dependency_graph_can_resolve_independent() {
        let field1 = FieldId::new("User".to_string(), "name".to_string());
        let field2 = FieldId::new("User".to_string(), "email".to_string());
        let mut deps = HashMap::new();
        deps.insert(field1.clone(), vec![FieldDependency::Independent]);
        deps.insert(field2.clone(), vec![FieldDependency::Independent]);
        let graph = DependencyGraph::new().with_dependencies(deps);
        assert!(graph.can_resolve(&field1).await);
        assert!(graph.can_resolve(&field2).await);
    }

    /// A data dependency blocks scheduling until the prerequisite is
    /// marked resolved.
    #[tokio::test]
    async fn test_dependency_graph_data_dependency() {
        let field1 = FieldId::new("User".to_string(), "id".to_string());
        let field2 = FieldId::new("Post".to_string(), "posts".to_string());
        let mut deps = HashMap::new();
        deps.insert(field1.clone(), vec![FieldDependency::Independent]);
        deps.insert(
            field2.clone(),
            vec![FieldDependency::DataDependency {
                depends_on: field1.clone(),
                reason: "Needs user ID".to_string(),
            }],
        );
        let graph = DependencyGraph::new().with_dependencies(deps);
        assert!(graph.can_resolve(&field1).await);
        assert!(!graph.can_resolve(&field2).await);
        graph.mark_resolved(field1.clone()).await;
        assert!(graph.can_resolve(&field2).await);
    }

    /// A "$userId" argument source turns into a data dependency on the
    /// sibling field named `userId`; sourceless fields stay independent.
    #[tokio::test]
    async fn test_dependency_analysis() {
        let field1 = FieldId::new("Query".to_string(), "userId".to_string());
        let field2 = FieldId::new("Query".to_string(), "userPosts".to_string());
        let mut metadata = HashMap::new();
        metadata.insert(
            field1.clone(),
            FieldMetadata {
                argument_sources: vec![],
                required_context: vec![],
                estimated_cost: 1.0,
                cacheable: true,
            },
        );
        metadata.insert(
            field2.clone(),
            FieldMetadata {
                argument_sources: vec!["$userId".to_string()],
                required_context: vec![],
                estimated_cost: 5.0,
                cacheable: false,
            },
        );
        let fields = vec![field1.clone(), field2.clone()];
        let deps = DependencyGraph::analyze_dependencies(&fields, &metadata);
        assert!(matches!(
            deps.get(&field1).expect("should succeed")[0],
            FieldDependency::Independent
        ));
        assert!(matches!(
            deps.get(&field2).expect("should succeed")[0],
            FieldDependency::DataDependency { .. }
        ));
    }

    /// A freshly built resolver reports zero active resolutions.
    #[tokio::test]
    async fn test_parallel_resolver_creation() {
        let config = ParallelResolutionConfig::default();
        let resolver = ParallelFieldResolver::new(config);
        let metrics = resolver.get_metrics();
        assert_eq!(metrics.active_resolutions, 0);
    }

    /// With parallelism disabled, tasks run in submission order and results
    /// come back in the same order.
    #[tokio::test]
    async fn test_sequential_resolution() {
        let config = ParallelResolutionConfig {
            enabled: false,
            ..Default::default()
        };
        let resolver = ParallelFieldResolver::new(config);
        let field1 = FieldId::new("Query".to_string(), "field1".to_string());
        let field2 = FieldId::new("Query".to_string(), "field2".to_string());
        let tasks = vec![
            FieldResolutionTask {
                field_id: field1.clone(),
                resolver: Arc::new(|| Ok(42)),
                dependencies: vec![FieldDependency::Independent],
                estimated_cost: 1.0,
                priority: 0,
            },
            FieldResolutionTask {
                field_id: field2.clone(),
                resolver: Arc::new(|| Ok(24)),
                dependencies: vec![FieldDependency::Independent],
                estimated_cost: 1.0,
                priority: 0,
            },
        ];
        let graph = Arc::new(DependencyGraph::new());
        let results = resolver
            .resolve_fields(tasks, graph)
            .await
            .expect("should succeed");
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].result.as_ref().expect("should succeed"), &42);
        assert_eq!(results[1].result.as_ref().expect("should succeed"), &24);
    }

    /// Three independent 10 ms resolvers should finish well under the
    /// sequential sum, and the speedup gauge should exceed 1.0.
    #[tokio::test]
    async fn test_parallel_resolution_independent_fields() {
        let config = ParallelResolutionConfig {
            enabled: true,
            min_fields_for_parallel: 2,
            ..Default::default()
        };
        let resolver = ParallelFieldResolver::new(config);
        let field1 = FieldId::new("Query".to_string(), "field1".to_string());
        let field2 = FieldId::new("Query".to_string(), "field2".to_string());
        let field3 = FieldId::new("Query".to_string(), "field3".to_string());
        let mut deps = HashMap::new();
        deps.insert(field1.clone(), vec![FieldDependency::Independent]);
        deps.insert(field2.clone(), vec![FieldDependency::Independent]);
        deps.insert(field3.clone(), vec![FieldDependency::Independent]);
        let tasks = vec![
            FieldResolutionTask {
                field_id: field1.clone(),
                resolver: Arc::new(|| {
                    std::thread::sleep(Duration::from_millis(10));
                    Ok(1)
                }),
                dependencies: vec![FieldDependency::Independent],
                estimated_cost: 1.0,
                priority: 0,
            },
            FieldResolutionTask {
                field_id: field2.clone(),
                resolver: Arc::new(|| {
                    std::thread::sleep(Duration::from_millis(10));
                    Ok(2)
                }),
                dependencies: vec![FieldDependency::Independent],
                estimated_cost: 1.0,
                priority: 0,
            },
            FieldResolutionTask {
                field_id: field3.clone(),
                resolver: Arc::new(|| {
                    std::thread::sleep(Duration::from_millis(10));
                    Ok(3)
                }),
                dependencies: vec![FieldDependency::Independent],
                estimated_cost: 1.0,
                priority: 0,
            },
        ];
        let graph = Arc::new(DependencyGraph::new().with_dependencies(deps));
        let start = Instant::now();
        let results = resolver
            .resolve_fields(tasks, graph)
            .await
            .expect("should succeed");
        let elapsed = start.elapsed();
        assert_eq!(results.len(), 3);
        assert!(
            elapsed < Duration::from_millis(500),
            "Elapsed time {:?} should be less than 500ms for parallel execution",
            elapsed
        );
        let metrics = resolver.get_metrics();
        assert!(metrics.parallelization_rate > 1.0);
    }

    /// field2 and field3 both depend on field1: field1 must finish before
    /// either dependent (checked via `resolved_at` timestamps).
    #[tokio::test]
    async fn test_parallel_resolution_with_dependencies() {
        let config = ParallelResolutionConfig {
            enabled: true,
            min_fields_for_parallel: 2,
            ..Default::default()
        };
        let resolver = ParallelFieldResolver::new(config);
        let field1 = FieldId::new("Query".to_string(), "field1".to_string());
        let field2 = FieldId::new("Query".to_string(), "field2".to_string());
        let field3 = FieldId::new("Query".to_string(), "field3".to_string());
        let mut deps = HashMap::new();
        deps.insert(field1.clone(), vec![FieldDependency::Independent]);
        deps.insert(
            field2.clone(),
            vec![FieldDependency::DataDependency {
                depends_on: field1.clone(),
                reason: "Depends on field1".to_string(),
            }],
        );
        deps.insert(
            field3.clone(),
            vec![FieldDependency::DataDependency {
                depends_on: field1.clone(),
                reason: "Depends on field1".to_string(),
            }],
        );
        let tasks = vec![
            FieldResolutionTask {
                field_id: field1.clone(),
                resolver: Arc::new(|| {
                    std::thread::sleep(Duration::from_millis(10));
                    Ok(1)
                }),
                dependencies: deps.get(&field1).expect("should succeed").clone(),
                estimated_cost: 1.0,
                priority: 0,
            },
            FieldResolutionTask {
                field_id: field2.clone(),
                resolver: Arc::new(|| {
                    std::thread::sleep(Duration::from_millis(10));
                    Ok(2)
                }),
                dependencies: deps.get(&field2).expect("should succeed").clone(),
                estimated_cost: 1.0,
                priority: 0,
            },
            FieldResolutionTask {
                field_id: field3.clone(),
                resolver: Arc::new(|| {
                    std::thread::sleep(Duration::from_millis(10));
                    Ok(3)
                }),
                dependencies: deps.get(&field3).expect("should succeed").clone(),
                estimated_cost: 1.0,
                priority: 0,
            },
        ];
        let graph = Arc::new(DependencyGraph::new().with_dependencies(deps));
        let results = resolver
            .resolve_fields(tasks, graph)
            .await
            .expect("should succeed");
        assert_eq!(results.len(), 3);
        let field1_result = results
            .iter()
            .find(|r| r.field_id == field1)
            .expect("should succeed");
        let field2_result = results
            .iter()
            .find(|r| r.field_id == field2)
            .expect("should succeed");
        let field3_result = results
            .iter()
            .find(|r| r.field_id == field3)
            .expect("should succeed");
        assert!(field1_result.resolved_at <= field2_result.resolved_at);
        assert!(field1_result.resolved_at <= field3_result.resolved_at);
    }

    /// A resolver that sleeps past `field_timeout` is reported as a per-field
    /// timeout error rather than failing the whole batch.
    #[tokio::test]
    async fn test_resolution_timeout() {
        let config = ParallelResolutionConfig {
            enabled: true,
            min_fields_for_parallel: 1,
            field_timeout: Duration::from_millis(50),
            ..Default::default()
        };
        let resolver = ParallelFieldResolver::new(config);
        let field1 = FieldId::new("Query".to_string(), "slow_field".to_string());
        let mut deps = HashMap::new();
        deps.insert(field1.clone(), vec![FieldDependency::Independent]);
        let tasks = vec![FieldResolutionTask {
            field_id: field1.clone(),
            resolver: Arc::new(|| {
                std::thread::sleep(Duration::from_millis(200));
                Ok(42)
            }),
            dependencies: vec![FieldDependency::Independent],
            estimated_cost: 1.0,
            priority: 0,
        }];
        let graph = Arc::new(DependencyGraph::new().with_dependencies(deps));
        let results = resolver
            .resolve_fields(tasks, graph)
            .await
            .expect("should succeed");
        assert_eq!(results.len(), 1);
        assert!(results[0].result.is_err());
        assert!(results[0].result.as_ref().unwrap_err().contains("timeout"));
    }

    /// Sanity-check the default configuration values.
    #[test]
    fn test_config_defaults() {
        let config = ParallelResolutionConfig::default();
        assert!(config.enabled);
        assert!(config.adaptive_concurrency);
        assert!(config.enable_work_stealing);
        assert!(config.min_fields_for_parallel >= 2);
    }
}